diff --git a/chancloser.go b/chancloser.go index 8b007807..3536f1ce 100644 --- a/chancloser.go +++ b/chancloser.go @@ -28,10 +28,10 @@ var ( // a message while it is in an unknown state. ErrInvalidState = fmt.Errorf("invalid state") - // errUpfrontShutdownScriptMismatch is returned when our peer sends us a - // shutdown message with a script that does not match the upfront shutdown - // script previously set. - errUpfrontShutdownScriptMismatch = fmt.Errorf("peer's shutdown " + + // errUpfrontShutdownScriptMismatch is returned when a peer or end user + // provides a script to cooperatively close out to which does not match + // the upfront shutdown script previously set for that party. + errUpfrontShutdownScriptMismatch = fmt.Errorf("shutdown " + "script does not match upfront shutdown script") ) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index aa040f4a..b78f43b1 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -105,6 +105,9 @@ type ChanClose struct { // process for the cooperative closure transaction kicks off. TargetFeePerKw chainfee.SatPerKWeight + // DeliveryScript is an optional delivery script to pay funds out to. + DeliveryScript lnwire.DeliveryAddress + // Updates is used by request creator to receive the notifications about // execution of the close channel request. Updates chan interface{} @@ -1365,12 +1368,13 @@ func (s *Switch) teardownCircuit(pkt *htlcPacket) error { } // CloseLink creates and sends the close channel command to the target link -// directing the specified closure type. If the closure type if CloseRegular, -// then the last parameter should be the ideal fee-per-kw that will be used as -// a starting point for close negotiation. -func (s *Switch) CloseLink(chanPoint *wire.OutPoint, closeType ChannelCloseType, - targetFeePerKw chainfee.SatPerKWeight) (chan interface{}, - chan error) { +// directing the specified closure type. If the closure type is CloseRegular, +// targetFeePerKw parameter should be the ideal fee-per-kw that will be used as +// a starting point for close negotiation. The deliveryScript parameter is an +// optional parameter which sets a user specified script to close out to. +func (s *Switch) CloseLink(chanPoint *wire.OutPoint, + closeType ChannelCloseType, targetFeePerKw chainfee.SatPerKWeight, + deliveryScript lnwire.DeliveryAddress) (chan interface{}, chan error) { // TODO(roasbeef) abstract out the close updates. updateChan := make(chan interface{}, 2) @@ -1381,6 +1385,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, closeType ChannelCloseType, ChanPoint: chanPoint, Updates: updateChan, TargetFeePerKw: targetFeePerKw, + DeliveryScript: deliveryScript, Err: errChan, } diff --git a/peer.go b/peer.go index 9996fd52..93f079ef 100644 --- a/peer.go +++ b/peer.go @@ -2122,6 +2122,37 @@ func (p *peer) fetchActiveChanCloser(chanID lnwire.ChannelID) (*channelCloser, e return chanCloser, nil } +// chooseDeliveryScript takes two optionally set shutdown scripts and returns +// a suitable script to close out to. This may be nil if neither script is +// set. If both scripts are set, this function will error if they do not match. +func chooseDeliveryScript(upfront, + requested lnwire.DeliveryAddress) (lnwire.DeliveryAddress, error) { + + // If no upfront upfront shutdown script was provided, return the user + // requested address (which may be nil). + if len(upfront) == 0 { + return requested, nil + } + + // If an upfront shutdown script was provided, and the user did not request + // a custom shutdown script, return the upfront address. + if len(requested) == 0 { + return upfront, nil + } + + // If both an upfront shutdown script and a custom close script were + // provided, error if the user provided shutdown script does not match + // the upfront shutdown script (because closing out to a different script + // would violate upfront shutdown). + if !bytes.Equal(upfront, requested) { + return nil, errUpfrontShutdownScriptMismatch + } + + // The user requested script matches the upfront shutdown script, so we + // can return it without error. + return upfront, nil +} + // handleLocalCloseReq kicks-off the workflow to execute a cooperative or // forced unilateral closure of the channel initiated by a local subsystem. func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) { @@ -2144,13 +2175,26 @@ func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) { // out this channel on-chain, so we execute the cooperative channel // closure workflow. case htlcswitch.CloseRegular: - // First, we'll fetch a delivery script that we'll use to send the - // funds to in the case of a successful negotiation. If an upfront - // shutdown script was set, we will use it. Otherwise, we get a fresh - // delivery script. - deliveryScript := channel.LocalUpfrontShutdownScript() + // First, we'll choose a delivery address that we'll use to send the + // funds to in the case of a successful negotiation. + + // An upfront shutdown and user provided script are both optional, + // but must be equal if both set (because we cannot serve a request + // to close out to a script which violates upfront shutdown). Get the + // appropriate address to close out to (which may be nil if neither + // are set) and error if they are both set and do not match. + deliveryScript, err := chooseDeliveryScript( + channel.LocalUpfrontShutdownScript(), req.DeliveryScript, + ) + if err != nil { + peerLog.Errorf("cannot close channel %v: %v", req.ChanPoint, err) + req.Err <- err + return + } + + // If neither an upfront address or a user set address was + // provided, generate a fresh script. if len(deliveryScript) == 0 { - var err error deliveryScript, err = p.genDeliveryScript() if err != nil { peerLog.Errorf(err.Error()) diff --git a/peer_test.go b/peer_test.go index 424405fc..14cb1979 100644 --- a/peer_test.go +++ b/peer_test.go @@ -3,17 +3,32 @@ package lnd import ( + "bytes" "testing" "time" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/lightningnetwork/lnd/lnwire" ) +var ( + // p2SHAddress is a valid pay to script hash address. + p2SHAddress = "2NBFNJTktNa7GZusGbDbGKRZTxdK9VVez3n" + + // p2wshAddress is a valid pay to witness script hash address. + p2wshAddress = "bc1qrp33g0q5c5txsp9arysrx4k6zdkfs4nce4xj0gdcccefvpysxf3qccfmv3" + + // timeout is a timeout value to use for tests which need ot wait for + // a return value on a channel. + timeout = time.Second * 5 +) + // TestPeerChannelClosureAcceptFeeResponder tests the shutdown responder's // behavior if we can agree on the fee immediately. func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { @@ -25,7 +40,8 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { broadcastTxChan := make(chan *wire.MsgTx) responder, responderChan, initiatorChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan) + notifier, broadcastTxChan, noUpdate, + ) if err != nil { t.Fatalf("unable to create test channels: %v", err) } @@ -45,7 +61,7 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { select { case outMsg := <-responder.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive shutdown message") } @@ -61,7 +77,7 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { select { case outMsg := <-responder.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive ClosingSigned message") } @@ -94,7 +110,7 @@ func TestPeerChannelClosureAcceptFeeResponder(t *testing.T) { // the closing transaction. select { case <-broadcastTxChan: - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("closing tx not broadcast") } @@ -113,7 +129,8 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { broadcastTxChan := make(chan *wire.MsgTx) initiator, initiatorChan, responderChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan) + notifier, broadcastTxChan, noUpdate, + ) if err != nil { t.Fatalf("unable to create test channels: %v", err) } @@ -136,7 +153,7 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { select { case outMsg := <-initiator.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive shutdown request") } @@ -184,7 +201,7 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { select { case outMsg := <-initiator.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed message") } @@ -202,7 +219,7 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) { // the closing transaction. select { case <-broadcastTxChan: - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("closing tx not broadcast") } @@ -222,7 +239,7 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { broadcastTxChan := make(chan *wire.MsgTx) responder, responderChan, initiatorChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan, + notifier, broadcastTxChan, noUpdate, ) if err != nil { t.Fatalf("unable to create test channels: %v", err) @@ -244,7 +261,7 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { select { case outMsg := <-responder.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive shutdown message") } @@ -260,7 +277,7 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { select { case outMsg := <-responder.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed message") } @@ -296,7 +313,7 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { select { case outMsg := <-responder.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed message") } @@ -338,7 +355,7 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { select { case outMsg := <-responder.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed message") } @@ -382,7 +399,7 @@ func TestPeerChannelClosureFeeNegotiationsResponder(t *testing.T) { // the closing transaction. select { case <-broadcastTxChan: - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("closing tx not broadcast") } @@ -402,7 +419,8 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { broadcastTxChan := make(chan *wire.MsgTx) initiator, initiatorChan, responderChan, cleanUp, err := createTestPeer( - notifier, broadcastTxChan) + notifier, broadcastTxChan, noUpdate, + ) if err != nil { t.Fatalf("unable to create test channels: %v", err) } @@ -426,7 +444,7 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { select { case outMsg := <-initiator.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive shutdown request") } @@ -477,7 +495,7 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { select { case outMsg := <-initiator.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed") } closingSignedMsg, ok := msg.(*lnwire.ClosingSigned) @@ -495,7 +513,7 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { select { case outMsg := <-initiator.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed") } closingSignedMsg, ok = msg.(*lnwire.ClosingSigned) @@ -541,7 +559,7 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { select { case outMsg := <-initiator.outgoingQueue: msg = outMsg.msg - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("did not receive closing signed") } @@ -584,7 +602,233 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) { // Wait for closing tx to be broadcasted. select { case <-broadcastTxChan: - case <-time.After(time.Second * 5): + case <-time.After(timeout): t.Fatalf("closing tx not broadcast") } } + +// TestChooseDeliveryScript tests that chooseDeliveryScript correctly errors +// when upfront and user set scripts that do not match are provided, allows +// matching values and returns appropriate values in the case where one or none +// are set. +func TestChooseDeliveryScript(t *testing.T) { + // generate non-zero scripts for testing. + script1 := genScript(t, p2SHAddress) + script2 := genScript(t, p2wshAddress) + + tests := []struct { + name string + userScript lnwire.DeliveryAddress + shutdownScript lnwire.DeliveryAddress + expectedScript lnwire.DeliveryAddress + expectedError error + }{ + { + name: "Neither set", + userScript: nil, + shutdownScript: nil, + expectedScript: nil, + expectedError: nil, + }, + { + name: "Both set and equal", + userScript: script1, + shutdownScript: script1, + expectedScript: script1, + expectedError: nil, + }, + { + name: "Both set and not equal", + userScript: script1, + shutdownScript: script2, + expectedScript: nil, + expectedError: errUpfrontShutdownScriptMismatch, + }, + { + name: "Only upfront script", + userScript: nil, + shutdownScript: script1, + expectedScript: script1, + expectedError: nil, + }, + { + name: "Only user script", + userScript: script2, + shutdownScript: nil, + expectedScript: script2, + expectedError: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + script, err := chooseDeliveryScript( + test.shutdownScript, test.userScript, + ) + if err != test.expectedError { + t.Fatalf("Expected: %v, got: %v", test.expectedError, err) + } + + if !bytes.Equal(script, test.expectedScript) { + t.Fatalf("Expected: %x, got: %x", test.expectedScript, script) + } + }) + } +} + +// TestCustomShutdownScript tests that the delivery script of a shutdown +// message can be set to a specified address. It checks that setting a close +// script fails for channels which have an upfront shutdown script already set. +func TestCustomShutdownScript(t *testing.T) { + script := genScript(t, p2SHAddress) + + // setShutdown is a function which sets the upfront shutdown address for + // the local channel. + setShutdown := func(a, b *channeldb.OpenChannel) { + a.LocalShutdownScript = script + b.RemoteShutdownScript = script + } + + tests := []struct { + name string + + // update is a function used to set values on the channel set up for the + // test. It is used to set values for upfront shutdown addresses. + update func(a, b *channeldb.OpenChannel) + + // userCloseScript is the address specified by the user. + userCloseScript lnwire.DeliveryAddress + + // expectedScript is the address we expect to be set on the shutdown + // message. + expectedScript lnwire.DeliveryAddress + + // expectedError is the error we expect, if any. + expectedError error + }{ + { + name: "User set script", + update: noUpdate, + userCloseScript: script, + expectedScript: script, + }, + { + name: "No user set script", + update: noUpdate, + }, + { + name: "Shutdown set, no user script", + update: setShutdown, + expectedScript: script, + }, + { + name: "Shutdown set, user script matches", + update: setShutdown, + userCloseScript: script, + expectedScript: script, + }, + { + name: "Shutdown set, user script different", + update: setShutdown, + userCloseScript: []byte("different addr"), + expectedError: errUpfrontShutdownScriptMismatch, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + notifier := &mockNotfier{ + confChannel: make(chan *chainntnfs.TxConfirmation), + } + broadcastTxChan := make(chan *wire.MsgTx) + + // Open a channel. + initiator, initiatorChan, _, cleanUp, err := createTestPeer( + notifier, broadcastTxChan, test.update, + ) + if err != nil { + t.Fatalf("unable to create test channels: %v", err) + } + defer cleanUp() + + // Request initiator to cooperatively close the channel, with + // a specified delivery address. + updateChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + chanPoint := initiatorChan.ChannelPoint() + closeCommand := htlcswitch.ChanClose{ + CloseType: htlcswitch.CloseRegular, + ChanPoint: chanPoint, + Updates: updateChan, + TargetFeePerKw: 12500, + DeliveryScript: test.userCloseScript, + Err: errChan, + } + + // Send the close command for the correct channel and check that a + // shutdown message is sent. + initiator.localCloseChanReqs <- &closeCommand + + var msg lnwire.Message + select { + case outMsg := <-initiator.outgoingQueue: + msg = outMsg.msg + case <-time.After(timeout): + t.Fatalf("did not receive shutdown message") + case err := <-errChan: + // Fail if we do not expect an error. + if err != test.expectedError { + t.Fatalf("error closing channel: %v", err) + } + + // Terminate the test early if have received an error, no + // further action is expected. + return + } + + // Check that we have received a shutdown message. + shutdownMsg, ok := msg.(*lnwire.Shutdown) + if !ok { + t.Fatalf("expected shutdown message, got %T", msg) + } + + // If the test has not specified an expected address, do not check + // whether the shutdown address matches. This covers the case where + // we epect shutdown to a random address and cannot match it. + if len(test.expectedScript) == 0 { + return + } + + // Check that the Shutdown message includes the expected delivery + // script. + if !bytes.Equal(test.expectedScript, shutdownMsg.Address) { + t.Fatalf("expected delivery script: %x, got: %x", + test.expectedScript, shutdownMsg.Address) + } + }) + } +} + +// genScript creates a script paying out to the address provided, which must +// be a valid address. +func genScript(t *testing.T, address string) lnwire.DeliveryAddress { + // Generate an address which can be used for testing. + deliveryAddr, err := btcutil.DecodeAddress( + address, + activeNetParams.Params, + ) + if err != nil { + t.Fatalf("invalid delivery address: %v", err) + } + + script, err := txscript.PayToAddrScript(deliveryAddr) + if err != nil { + t.Fatalf("cannot create script: %v", err) + } + + return script +} diff --git a/rpcserver.go b/rpcserver.go index bf2a12a8..cf0e8bd4 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1881,8 +1881,28 @@ func (r *rpcServer) CloseChannel(in *lnrpc.CloseChannelRequest, // cooperative channel closure. So we'll forward the request to // the htlc switch which will handle the negotiation and // broadcast details. + + var deliveryScript lnwire.DeliveryAddress + + // If a delivery address to close out to was specified, decode it. + if len(in.DeliveryAddress) > 0 { + // Decode the address provided. + addr, err := btcutil.DecodeAddress( + in.DeliveryAddress, activeNetParams.Params, + ) + if err != nil { + return fmt.Errorf("invalid delivery address: %v", err) + } + + // Create a script to pay out to the address provided. + deliveryScript, err = txscript.PayToAddrScript(addr) + if err != nil { + return err + } + } + updateChan, errChan = r.server.htlcSwitch.CloseLink( - chanPoint, htlcswitch.CloseRegular, feeRate, + chanPoint, htlcswitch.CloseRegular, feeRate, deliveryScript, ) } out: diff --git a/server.go b/server.go index 4c89aa47..dd8fadc2 100644 --- a/server.go +++ b/server.go @@ -824,7 +824,11 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, closureType htlcswitch.ChannelCloseType) { // TODO(conner): Properly respect the update and error channels // returned by CloseLink. - s.htlcSwitch.CloseLink(chanPoint, closureType, 0) + + // Instruct the switch to close the channel. Provide no close out + // delivery script or target fee per kw because user input is not + // available when the remote peer closes the channel. + s.htlcSwitch.CloseLink(chanPoint, closureType, 0, nil) } // We will use the following channel to reliably hand off contract diff --git a/test_utils.go b/test_utils.go index f6f688c1..afb16cd5 100644 --- a/test_utils.go +++ b/test_utils.go @@ -90,10 +90,16 @@ var ( } ) +// noUpdate is a function which can be used as a parameter in createTestPeer to +// call the setup code with no custom values on the channels set up. +var noUpdate = func(a, b *channeldb.OpenChannel) {} + // createTestPeer creates a channel between two nodes, and returns a peer for -// one of the nodes, together with the channel seen from both nodes. -func createTestPeer(notifier chainntnfs.ChainNotifier, - publTx chan *wire.MsgTx) (*peer, *lnwallet.LightningChannel, +// one of the nodes, together with the channel seen from both nodes. It takes +// an updateChan function which can be used to modify the default values on +// the channel states for each peer. +func createTestPeer(notifier chainntnfs.ChainNotifier, publTx chan *wire.MsgTx, + updateChan func(a, b *channeldb.OpenChannel)) (*peer, *lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), error) { aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), @@ -285,6 +291,9 @@ func createTestPeer(notifier chainntnfs.ChainNotifier, Packager: channeldb.NewChannelPackager(shortChanID), } + // Set custom values on the channel states. + updateChan(aliceChannelState, bobChannelState) + aliceAddr := &net.TCPAddr{ IP: net.ParseIP("127.0.0.1"), Port: 18555,