From de3af9b0c01799cc0cb940ed6fed0bb94b71a3ff Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 10 Nov 2017 14:52:27 -0800 Subject: [PATCH] htlcswitch: modify Bandwidth() method on links to use more accurate accoutning MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In this commit, we modify the existing implementation of the Bandwidth() method on the default ChannelLink implementation to use much tighter accounting. Before this commit, there was a bug wherein if the link restarted with pending un-settled HTLC’s, and one of them was settled, then the bandwidth wouldn’t properly be updated to reflect this fact. To fix this, we’ve done away with the manual accounting and instead grab the current balances from two sources: the set of active HTLC’s within the overflow queue, and the report from the link itself which includes the pending HTLC’s and factors in the amount we’d need to (or not need to) pay in fees for each HTLC. --- htlcswitch/link.go | 59 +++++++++----------------------------------- htlcswitch/packet.go | 2 ++ htlcswitch/queue.go | 16 ++++++++++++ 3 files changed, 29 insertions(+), 48 deletions(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 13dfe8ed..c2ca512e 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -225,7 +225,6 @@ func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel, mailBox: newMemoryMailBox(), linkControl: make(chan interface{}), // TODO(roasbeef): just do reserve here? - availableBandwidth: uint64(channel.StateSnapshot().LocalBalance), logCommitTimer: time.NewTimer(300 * time.Millisecond), overflowQueue: newPacketQueue(lnwallet.MaxHTLCNumber / 2), bestHeight: currentHeight, @@ -255,6 +254,7 @@ func (l *channelLink) Start() error { log.Infof("ChannelLink(%v) is starting", l) + l.mailBox.Start() l.overflowQueue.Start() l.wg.Add(1) @@ -277,6 +277,7 @@ func (l *channelLink) Stop() { l.channel.Stop() + l.mailBox.Stop() l.overflowQueue.Stop() close(l.quit) @@ -463,12 +464,6 @@ out: htlc.PaymentHash[:], l.batchCounter) - // As we're adding a new pkt to the overflow - // queue, decrement the available bandwidth. - atomic.AddUint64( - &l.availableBandwidth, - -uint64(htlc.Amount), - ) l.overflowQueue.AddPkt(pkt) continue } @@ -541,16 +536,6 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { htlc.PaymentHash[:], l.batchCounter) - // If we're processing this HTLC for the first - // time, then we'll decrement the available - // bandwidth. As otherwise, we'd double count - // the effect of the HTLC. - if !isReProcess { - atomic.AddUint64( - &l.availableBandwidth, -uint64(htlc.Amount), - ) - } - l.overflowQueue.AddPkt(pkt) return @@ -603,8 +588,6 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { isObfuscated, ) - atomic.AddUint64(&l.availableBandwidth, uint64(htlc.Amount)) - // TODO(roasbeef): need to identify if sent // from switch so don't need to obfuscate go l.cfg.Switch.forward(failPkt) @@ -613,14 +596,6 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { } } - // If we're processing this HTLC for the first time, then we'll - // decrement the available bandwidth. - if !isReProcess { - // Subtract the available bandwidth as we have a new - // HTLC in limbo. - atomic.AddUint64(&l.availableBandwidth, -uint64(htlc.Amount)) - } - log.Tracef("Received downstream htlc: payment_hash=%x, "+ "local_log_index=%v, batch_size=%v", htlc.PaymentHash[:], index, l.batchCounter+1) @@ -633,17 +608,13 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { // upstream. Therefore we settle the HTLC within the our local // state machine. pre := htlc.PaymentPreimage - logIndex, amt, err := l.channel.SettleHTLC(pre) + logIndex, _, err := l.channel.SettleHTLC(pre) if err != nil { // TODO(roasbeef): broadcast on-chain l.fail("unable to settle incoming HTLC: %v", err) return } - // Increment the available bandwidth as we've settled an HTLC - // extended by tbe remote party. - atomic.AddUint64(&l.availableBandwidth, uint64(amt)) - // With the HTLC settled, we'll need to populate the wire // message to target the specific channel and HTLC to be // cancelled. @@ -778,19 +749,15 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // If remote side have been unable to parse the onion blob we // have sent to it, than we should transform the malformed HTLC // message to the usual HTLC fail message. - amt, err := l.channel.ReceiveFailHTLC(idx, b.Bytes()) + _, err := l.channel.ReceiveFailHTLC(msg.ID, b.Bytes()) if err != nil { l.fail("unable to handle upstream fail HTLC: %v", err) return } - // Increment the available bandwidth as they've removed our - // HTLC. - atomic.AddUint64(&l.availableBandwidth, uint64(amt)) - case *lnwire.UpdateFailHTLC: idx := msg.ID - amt, err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:]) + _, err := l.channel.ReceiveFailHTLC(idx, msg.Reason[:]) if err != nil { l.fail("unable to handle upstream fail HTLC: %v", err) return @@ -963,8 +930,11 @@ type getBandwidthCmd struct { // // NOTE: Part of the ChannelLink interface. func (l *channelLink) Bandwidth() lnwire.MilliSatoshi { - // TODO(roasbeef): subtract reserverj - return lnwire.MilliSatoshi(atomic.LoadUint64(&l.availableBandwidth)) + // TODO(roasbeef): subtract reserve + channelBandwidth := l.channel.AvailableBalance() + overflowBandwidth := l.overflowQueue.TotalHtlcAmount() + + return channelBandwidth - overflowBandwidth } // policyUpdate is a message sent to a channel link when an outside sub-system @@ -1276,19 +1246,12 @@ func (l *channelLink) processLockedInHtlcs( } preimage := invoice.Terms.PaymentPreimage - logIndex, amt, err := l.channel.SettleHTLC(preimage) + logIndex, _, err := l.channel.SettleHTLC(preimage) if err != nil { l.fail("unable to settle htlc: %v", err) return nil } - // Increment the available bandwidth as we've - // settled an HTLC extended by tbe remote - // party. - atomic.AddUint64( - &l.availableBandwidth, uint64(amt), - ) - // Notify the invoiceRegistry of the invoices // we just settled with this latest commitment // update. diff --git a/htlcswitch/packet.go b/htlcswitch/packet.go index b9cd8365..540335fe 100644 --- a/htlcswitch/packet.go +++ b/htlcswitch/packet.go @@ -51,6 +51,7 @@ type htlcPacket struct { func newInitPacket(destNode [33]byte, htlc *lnwire.UpdateAddHTLC) *htlcPacket { return &htlcPacket{ destNode: destNode, + amount: htlc.Amount, htlc: htlc, } } @@ -61,6 +62,7 @@ func newAddPacket(src, dest lnwire.ShortChannelID, htlc *lnwire.UpdateAddHTLC, e ErrorEncrypter) *htlcPacket { return &htlcPacket{ + amount: htlc.Amount, dest: dest, src: src, htlc: htlc, diff --git a/htlcswitch/queue.go b/htlcswitch/queue.go index 58ac10fe..22a7c382 100644 --- a/htlcswitch/queue.go +++ b/htlcswitch/queue.go @@ -3,6 +3,8 @@ package htlcswitch import ( "sync" "sync/atomic" + + "github.com/lightningnetwork/lnd/lnwire" ) // packetQueue is an goroutine-safe queue of htlc packets which over flow the @@ -23,6 +25,11 @@ type packetQueue struct { // with the lock held. queueLen int32 + // totalHtlcAmt is the sum of the value of all pending HTLC's currently + // residing within the overflow queue. This value should only read or + // modified *atomically*. + totalHtlcAmt int64 + queueCond *sync.Cond queueMtx sync.Mutex queue []*htlcPacket @@ -125,6 +132,7 @@ func (p *packetQueue) packetCoordinator() { p.queue[0] = nil p.queue = p.queue[1:] atomic.AddInt32(&p.queueLen, -1) + atomic.AddInt64(&p.totalHtlcAmt, int64(-nextPkt.amount)) p.queueCond.L.Unlock() case <-p.quit: return @@ -147,6 +155,7 @@ func (p *packetQueue) AddPkt(pkt *htlcPacket) { p.queueCond.L.Lock() p.queue = append(p.queue, pkt) atomic.AddInt32(&p.queueLen, 1) + atomic.AddInt64(&p.totalHtlcAmt, int64(pkt.amount)) p.queueCond.L.Unlock() // With the message added, we signal to the msgConsumer that there are @@ -180,3 +189,10 @@ func (p *packetQueue) SignalFreeSlot() { func (p *packetQueue) Length() int32 { return atomic.LoadInt32(&p.queueLen) } + +// TotalHtlcAmount is the total amount (in mSAT) of all HTLC's currently +// residing within the overflow queue. +func (p *packetQueue) TotalHtlcAmount() lnwire.MilliSatoshi { + // TODO(roasbeef): also factor in fee rate? + return lnwire.MilliSatoshi(atomic.LoadInt64(&p.totalHtlcAmt)) +}