diff --git a/chancloser.go b/chancloser.go index 257b6db1..786f6293 100644 --- a/chancloser.go +++ b/chancloser.go @@ -78,6 +78,10 @@ type chanCloseCfg struct { // broadcastTx broadcasts the passed transaction to the network. broadcastTx func(*wire.MsgTx) error + // disableChannel disables a channel, resulting in it not being able to + // forward payments. + disableChannel func(wire.OutPoint) error + // quit is a channel that should be sent upon in the occasion the state // machine should cease all progress and shutdown. quit chan struct{} @@ -436,6 +440,16 @@ func (c *channelCloser) ProcessCloseMsg(msg lnwire.Message) ([]lnwire.Message, b return nil, false, err } + // We'll attempt to disable the channel in the background to + // avoid blocking due to sending the update message to all + // active peers. + go func() { + if err := c.cfg.disableChannel(c.chanPoint); err != nil { + peerLog.Errorf("Unable to disable channel %v on "+ + "close: %v", c.chanPoint, err) + } + }() + // Finally, we'll transition to the closeFinished state, and // also return the final close signed message we sent. // Additionally, we return true for the second argument to diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index 9fc3023a..1f29fbce 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -122,6 +122,10 @@ type ChainArbitratorConfig struct { // ChainIO allows us to query the state of the current main chain. ChainIO lnwallet.BlockChainIO + + // DisableChannel disables a channel, resulting in it not being able to + // forward payments. + DisableChannel func(wire.OutPoint) error } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -667,6 +671,16 @@ func (c *ChainArbitrator) ForceCloseContract(chanPoint wire.OutPoint) (*wire.Msg return nil, fmt.Errorf("ChainArbitrator shutting down") } + // We'll attempt to disable the channel in the background to + // avoid blocking due to sending the update message to all + // active peers. + go func() { + if err := c.cfg.DisableChannel(chanPoint); err != nil { + log.Errorf("Unable to disable channel %v on "+ + "close: %v", chanPoint, err) + } + }() + return closeTx, nil } diff --git a/peer.go b/peer.go index 8c9afc5e..28c8d930 100644 --- a/peer.go +++ b/peer.go @@ -1677,6 +1677,7 @@ func (p *peer) fetchActiveChanCloser(chanID lnwire.ChannelID) (*channelCloser, e channel: channel, unregisterChannel: p.server.htlcSwitch.RemoveLink, broadcastTx: p.server.cc.wallet.PublishTransaction, + disableChannel: p.server.disableChannel, quit: p.quit, }, deliveryAddr, @@ -1733,11 +1734,13 @@ func (p *peer) handleLocalCloseReq(req *htlcswitch.ChanClose) { req.Err <- err return } + chanCloser := newChannelCloser( chanCloseCfg{ channel: channel, unregisterChannel: p.server.htlcSwitch.RemoveLink, broadcastTx: p.server.cc.wallet.PublishTransaction, + disableChannel: p.server.disableChannel, quit: p.quit, }, deliveryAddr, @@ -2013,8 +2016,8 @@ func fetchLastChanUpdate(s *server, } if edge1 == nil || edge2 == nil { - return nil, errors.Errorf("unable to find "+ - "channel by ShortChannelID(%v)", cid) + return nil, fmt.Errorf("unable to find channel by "+ + "ShortChannelID(%v)", cid) } // If we're the outgoing node on the first edge, then that @@ -2027,27 +2030,6 @@ func fetchLastChanUpdate(s *server, local = edge1 } - update := lnwire.ChannelUpdate{ - ChainHash: info.ChainHash, - ShortChannelID: lnwire.NewShortChanIDFromInt(local.ChannelID), - Timestamp: uint32(local.LastUpdate.Unix()), - Flags: local.Flags, - TimeLockDelta: local.TimeLockDelta, - HtlcMinimumMsat: local.MinHTLC, - BaseFee: uint32(local.FeeBaseMSat), - FeeRate: uint32(local.FeeProportionalMillionths), - } - update.Signature, err = lnwire.NewSigFromRawSignature(local.SigBytes) - if err != nil { - return nil, err - } - - hswcLog.Tracef("Sending latest channel_update: %v", - newLogClosure(func() string { - return spew.Sdump(update) - }), - ) - - return &update, nil + return extractChannelUpdate(info, local) } } diff --git a/server.go b/server.go index 2163fd1b..22b315c1 100644 --- a/server.go +++ b/server.go @@ -661,6 +661,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, return ErrServerShuttingDown } }, + DisableChannel: s.disableChannel, }, chanDB) s.breachArbiter = newBreachArbiter(&BreachConfig{ @@ -2588,3 +2589,122 @@ func (s *server) fetchNodeAdvertisedAddr(pub *btcec.PublicKey) (net.Addr, error) return node.Addresses[0], nil } + +// disableChannel disables a channel, resulting in it not being able to forward +// payments. This is done by sending a new channel update across the network +// with the disabled flag set. +func (s *server) disableChannel(op wire.OutPoint) error { + // Retrieve the latest update for this channel. We'll use this + // as our starting point to send the new update. + chanUpdate, err := s.fetchLastChanUpdateByOutPoint(op) + if err != nil { + return err + } + + // Set the bit responsible for marking a channel as disabled. + chanUpdate.Flags |= lnwire.ChanUpdateDisabled + + // We must now update the message's timestamp and generate a new + // signature. + chanUpdate.Timestamp = uint32(time.Now().Unix()) + + chanUpdateMsg, err := chanUpdate.DataToSign() + if err != nil { + return err + } + + pubKey := s.identityPriv.PubKey() + sig, err := s.nodeSigner.SignMessage(pubKey, chanUpdateMsg) + if err != nil { + return err + } + chanUpdate.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return err + } + + // Once signed, we'll send the new update to all of our peers. + return s.applyChannelUpdate(chanUpdate) +} + +// fetchLastChanUpdateByOutPoint fetches the latest update for a channel from +// our point of view. +func (s *server) fetchLastChanUpdateByOutPoint(op wire.OutPoint) (*lnwire.ChannelUpdate, error) { + graph := s.chanDB.ChannelGraph() + info, edge1, edge2, err := graph.FetchChannelEdgesByOutpoint(&op) + if err != nil { + return nil, err + } + + if edge1 == nil || edge2 == nil { + return nil, fmt.Errorf("unable to find channel(%v)", op) + } + + // If we're the outgoing node on the first edge, then that + // means the second edge is our policy. Otherwise, the first + // edge is our policy. + var local *channeldb.ChannelEdgePolicy + + ourPubKey := s.identityPriv.PubKey().SerializeCompressed() + if bytes.Equal(edge1.Node.PubKeyBytes[:], ourPubKey) { + local = edge2 + } else { + local = edge1 + } + + return extractChannelUpdate(info, local) +} + +// extractChannelUpdate retrieves a lnwire.ChannelUpdate message from an edge's +// info and routing policy. +func extractChannelUpdate(info *channeldb.ChannelEdgeInfo, + policy *channeldb.ChannelEdgePolicy) (*lnwire.ChannelUpdate, error) { + + update := &lnwire.ChannelUpdate{ + ChainHash: info.ChainHash, + ShortChannelID: lnwire.NewShortChanIDFromInt(policy.ChannelID), + Timestamp: uint32(policy.LastUpdate.Unix()), + Flags: policy.Flags, + TimeLockDelta: policy.TimeLockDelta, + HtlcMinimumMsat: policy.MinHTLC, + BaseFee: uint32(policy.FeeBaseMSat), + FeeRate: uint32(policy.FeeProportionalMillionths), + } + + var err error + update.Signature, err = lnwire.NewSigFromRawSignature(policy.SigBytes) + if err != nil { + return nil, err + } + + return update, nil +} + +// applyChannelUpdate applies the channel update to the different sub-systems of +// the server. +func (s *server) applyChannelUpdate(update *lnwire.ChannelUpdate) error { + newChannelPolicy := &channeldb.ChannelEdgePolicy{ + SigBytes: update.Signature.ToSignatureBytes(), + ChannelID: update.ShortChannelID.ToUint64(), + LastUpdate: time.Unix(int64(update.Timestamp), 0), + Flags: update.Flags, + TimeLockDelta: update.TimeLockDelta, + MinHTLC: update.HtlcMinimumMsat, + FeeBaseMSat: lnwire.MilliSatoshi(update.BaseFee), + FeeProportionalMillionths: lnwire.MilliSatoshi(update.FeeRate), + } + + err := s.chanRouter.UpdateEdge(newChannelPolicy) + if err != nil && !routing.IsError(err, routing.ErrIgnored) { + return err + } + + pubKey := s.identityPriv.PubKey() + errChan := s.authGossiper.ProcessLocalAnnouncement(update, pubKey) + select { + case err := <-errChan: + return err + case <-s.quit: + return ErrServerShuttingDown + } +}