diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 9298ceb1..afdd50cc 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -1272,72 +1272,6 @@ func (l *channelLink) handleDownstreamPkt(pkt *htlcPacket) { l.log.Warnf("Unable to handle downstream add HTLC: %v", err) - var ( - localFailure = false - reason lnwire.OpaqueReason - ) - - // Create a temporary channel failure which we will send - // back to our peer if this is a forward, or report to - // the user if the failed payment was locally initiated. - failure := l.createFailureWithUpdate( - func(upd *lnwire.ChannelUpdate) lnwire.FailureMessage { - return lnwire.NewTemporaryChannelFailure( - upd, - ) - }, - ) - - // If the payment was locally initiated (which is - // indicated by a nil obfuscator), we do not need to - // encrypt it back to the sender. - if pkt.obfuscator == nil { - var b bytes.Buffer - err := lnwire.EncodeFailure(&b, failure, 0) - if err != nil { - l.log.Errorf("unable to encode "+ - "failure: %v", err) - l.mailBox.AckPacket(pkt.inKey()) - return - } - reason = lnwire.OpaqueReason(b.Bytes()) - localFailure = true - } else { - // If the packet is part of a forward, - // (identified by a non-nil obfuscator) we need - // to encrypt the error back to the source. - var err error - reason, err = pkt.obfuscator.EncryptFirstHop(failure) - if err != nil { - l.log.Errorf("unable to "+ - "obfuscate error: %v", err) - l.mailBox.AckPacket(pkt.inKey()) - return - } - } - - // Create a link error containing the temporary channel - // failure and a detail which indicates the we failed to - // add the htlc. - linkError := NewDetailedLinkError( - failure, OutgoingFailureDownstreamHtlcAdd, - ) - - failPkt := &htlcPacket{ - incomingChanID: pkt.incomingChanID, - incomingHTLCID: pkt.incomingHTLCID, - circuit: pkt.circuit, - sourceRef: pkt.sourceRef, - hasSource: true, - localFailure: localFailure, - linkFailure: linkError, - htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, - }, - } - - go l.forwardBatch(failPkt) - // Remove this packet from the link's mailbox, this // prevents it from being reprocessed if the link // restarts and resets it mailbox. If this response @@ -1346,7 +1280,7 @@ func (l *channelLink) handleDownstreamPkt(pkt *htlcPacket) { // the switch, since the circuit was never fully opened, // and the forwarding package shows it as // unacknowledged. - l.mailBox.AckPacket(pkt.inKey()) + l.mailBox.FailAdd(pkt) return } diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index e32d75a7..54c918b8 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -1,6 +1,7 @@ package htlcswitch import ( + "bytes" "container/list" "errors" "sync" @@ -31,8 +32,17 @@ type MailBox interface { // AckPacket removes a packet from the mailboxes in-memory replay // buffer. This will prevent a packet from being delivered after a link - // restarts if the switch has remained online. - AckPacket(CircuitKey) + // restarts if the switch has remained online. The returned boolean + // indicates whether or not a packet with the passed incoming circuit + // key was removed. + AckPacket(CircuitKey) bool + + // FailAdd fails an UpdateAddHTLC that exists within the mailbox, + // removing it from the in-memory replay buffer. This will prevent the + // packet from being delivered after the link restarts if the switch has + // remained online. The generated LinkError will show an + // OutgoingFailureDownstreamHtlcAdd FailureDetail. + FailAdd(pkt *htlcPacket) // MessageOutBox returns a channel that any new messages ready for // delivery will be sent on. @@ -56,12 +66,29 @@ type MailBox interface { Stop() } +type mailBoxConfig struct { + // shortChanID is the short channel id of the channel this mailbox + // belongs to. + shortChanID lnwire.ShortChannelID + + // fetchUpdate retreives the most recent channel update for the channel + // this mailbox belongs to. + fetchUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) + + // forwardPackets send a varidic number of htlcPackets to the switch to + // be routed. A quit channel should be provided so that the call can + // properly exit during shutdown. + forwardPackets func(chan struct{}, ...*htlcPacket) chan error +} + // memoryMailBox is an implementation of the MailBox struct backed by purely // in-memory queues. type memoryMailBox struct { started sync.Once stopped sync.Once + cfg *mailBoxConfig + wireMessages *list.List wireMtx sync.Mutex wireCond *sync.Cond @@ -84,8 +111,9 @@ type memoryMailBox struct { } // newMemoryMailBox creates a new instance of the memoryMailBox. -func newMemoryMailBox() *memoryMailBox { +func newMemoryMailBox(cfg *mailBoxConfig) *memoryMailBox { box := &memoryMailBox{ + cfg: cfg, wireMessages: list.New(), htlcPkts: list.New(), messageOutbox: make(chan lnwire.Message), @@ -179,20 +207,23 @@ func (m *memoryMailBox) signalUntilReset(cType courierType, } // AckPacket removes the packet identified by it's incoming circuit key from the -// queue of packets to be delivered. +// queue of packets to be delivered. The returned boolean indicates whether or +// not a packet with the passed incoming circuit key was removed. // // NOTE: It is safe to call this method multiple times for the same circuit key. -func (m *memoryMailBox) AckPacket(inKey CircuitKey) { +func (m *memoryMailBox) AckPacket(inKey CircuitKey) bool { m.pktCond.L.Lock() entry, ok := m.pktIndex[inKey] if !ok { m.pktCond.L.Unlock() - return + return false } m.htlcPkts.Remove(entry) delete(m.pktIndex, inKey) m.pktCond.L.Unlock() + + return true } // HasPacket queries the packets for a circuit key, this is used to drop packets @@ -410,6 +441,80 @@ func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error { return nil } +// FailAdd fails an UpdateAddHTLC that exists within the mailbox, removing it +// from the in-memory replay buffer. This will prevent the packet from being +// delivered after the link restarts if the switch has remained online. The +// generated LinkError will show an OutgoingFailureDownstreamHtlcAdd +// FailureDetail. +func (m *memoryMailBox) FailAdd(pkt *htlcPacket) { + // First, remove the packet from mailbox. If we didn't find the packet + // because it has already been acked, we'll exit early to avoid sending + // a duplicate fail message through the switch. + if !m.AckPacket(pkt.inKey()) { + return + } + + var ( + localFailure = false + reason lnwire.OpaqueReason + ) + + // Create a temporary channel failure which we will send back to our + // peer if this is a forward, or report to the user if the failed + // payment was locally initiated. + var failure lnwire.FailureMessage + update, err := m.cfg.fetchUpdate(m.cfg.shortChanID) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure(update) + } + + // If the payment was locally initiated (which is indicated by a nil + // obfuscator), we do not need to encrypt it back to the sender. + if pkt.obfuscator == nil { + var b bytes.Buffer + err := lnwire.EncodeFailure(&b, failure, 0) + if err != nil { + log.Errorf("Unable to encode failure: %v", err) + return + } + reason = lnwire.OpaqueReason(b.Bytes()) + localFailure = true + } else { + // If the packet is part of a forward, (identified by a non-nil + // obfuscator) we need to encrypt the error back to the source. + var err error + reason, err = pkt.obfuscator.EncryptFirstHop(failure) + if err != nil { + log.Errorf("Unable to obfuscate error: %v", err) + return + } + } + + // Create a link error containing the temporary channel failure and a + // detail which indicates the we failed to add the htlc. + linkError := NewDetailedLinkError( + failure, OutgoingFailureDownstreamHtlcAdd, + ) + + failPkt := &htlcPacket{ + incomingChanID: pkt.incomingChanID, + incomingHTLCID: pkt.incomingHTLCID, + circuit: pkt.circuit, + sourceRef: pkt.sourceRef, + hasSource: true, + localFailure: localFailure, + linkFailure: linkError, + htlc: &lnwire.UpdateFailHTLC{ + Reason: reason, + }, + } + + errChan := m.cfg.forwardPackets(m.quit, failPkt) + go handleBatchFwdErrs(errChan, log) +} + // MessageOutBox returns a channel that any new messages ready for delivery // will be sent on. // @@ -434,6 +539,8 @@ func (m *memoryMailBox) PacketOutBox() chan *htlcPacket { type mailOrchestrator struct { mu sync.RWMutex + cfg *mailOrchConfig + // mailboxes caches exactly one mailbox for all known channels. mailboxes map[lnwire.ChannelID]MailBox @@ -454,9 +561,21 @@ type mailOrchestrator struct { unclaimedPackets map[lnwire.ShortChannelID][]*htlcPacket } +type mailOrchConfig struct { + // forwardPackets send a varidic number of htlcPackets to the switch to + // be routed. A quit channel should be provided so that the call can + // properly exit during shutdown. + forwardPackets func(chan struct{}, ...*htlcPacket) chan error + + // fetchUpdate retreives the most recent channel update for the channel + // this mailbox belongs to. + fetchUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) +} + // newMailOrchestrator initializes a fresh mailOrchestrator. -func newMailOrchestrator() *mailOrchestrator { +func newMailOrchestrator(cfg *mailOrchConfig) *mailOrchestrator { return &mailOrchestrator{ + cfg: cfg, mailboxes: make(map[lnwire.ChannelID]MailBox), liveIndex: make(map[lnwire.ShortChannelID]lnwire.ChannelID), unclaimedPackets: make(map[lnwire.ShortChannelID][]*htlcPacket), @@ -472,7 +591,9 @@ func (mo *mailOrchestrator) Stop() { // GetOrCreateMailBox returns an existing mailbox belonging to `chanID`, or // creates and returns a new mailbox if none is found. -func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID) MailBox { +func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID, + shortChanID lnwire.ShortChannelID) MailBox { + // First, try lookup the mailbox directly using only the shared mutex. mo.mu.RLock() mailbox, ok := mo.mailboxes[chanID] @@ -485,7 +606,7 @@ func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID) MailBox // Otherwise, we will try again with exclusive lock, creating a mailbox // if one still has not been created. mo.mu.Lock() - mailbox = mo.exclusiveGetOrCreateMailBox(chanID) + mailbox = mo.exclusiveGetOrCreateMailBox(chanID, shortChanID) mo.mu.Unlock() return mailbox @@ -497,11 +618,15 @@ func (mo *mailOrchestrator) GetOrCreateMailBox(chanID lnwire.ChannelID) MailBox // // NOTE: This method MUST be invoked with the mailOrchestrator's exclusive lock. func (mo *mailOrchestrator) exclusiveGetOrCreateMailBox( - chanID lnwire.ChannelID) MailBox { + chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID) MailBox { mailbox, ok := mo.mailboxes[chanID] if !ok { - mailbox = newMemoryMailBox() + mailbox = newMemoryMailBox(&mailBoxConfig{ + shortChanID: shortChanID, + fetchUpdate: mo.cfg.fetchUpdate, + forwardPackets: mo.cfg.forwardPackets, + }) mailbox.Start() mo.mailboxes[chanID] = mailbox } @@ -581,7 +706,7 @@ func (mo *mailOrchestrator) Deliver( // index should only be set if the mailbox had been initialized // beforehand. However, this does ensure that this case is // handled properly in the event that it could happen. - mailbox = mo.exclusiveGetOrCreateMailBox(chanID) + mailbox = mo.exclusiveGetOrCreateMailBox(chanID, sid) mo.mu.Unlock() // Deliver the packet to the mailbox if it was found or created. diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index 83dfc282..ccf35e87 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -19,7 +19,7 @@ func TestMailBoxCouriers(t *testing.T) { // First, we'll create new instance of the current default mailbox // type. - mailBox := newMemoryMailBox() + mailBox := newMemoryMailBox(&mailBoxConfig{}) mailBox.Start() defer mailBox.Stop() @@ -153,7 +153,7 @@ func TestMailBoxCouriers(t *testing.T) { func TestMailBoxResetAfterShutdown(t *testing.T) { t.Parallel() - m := newMemoryMailBox() + m := newMemoryMailBox(&mailBoxConfig{}) m.Start() // Stop the mailbox, then try to reset the message and packet couriers. @@ -170,6 +170,144 @@ func TestMailBoxResetAfterShutdown(t *testing.T) { } } +type mailboxContext struct { + t *testing.T + mailbox MailBox + forwards chan *htlcPacket +} + +func newMailboxContext(t *testing.T) *mailboxContext { + + ctx := &mailboxContext{ + t: t, + forwards: make(chan *htlcPacket, 1), + } + ctx.mailbox = newMemoryMailBox(&mailBoxConfig{ + fetchUpdate: func(sid lnwire.ShortChannelID) ( + *lnwire.ChannelUpdate, error) { + return &lnwire.ChannelUpdate{ + ShortChannelID: sid, + }, nil + }, + forwardPackets: ctx.forward, + }) + ctx.mailbox.Start() + + return ctx +} + +func (c *mailboxContext) forward(_ chan struct{}, + pkts ...*htlcPacket) chan error { + + for _, pkt := range pkts { + c.forwards <- pkt + } + + errChan := make(chan error) + close(errChan) + + return errChan +} + +func (c *mailboxContext) sendAdds(start, num int) []*htlcPacket { + c.t.Helper() + + sentPackets := make([]*htlcPacket, num) + for i := 0; i < num; i++ { + pkt := &htlcPacket{ + outgoingChanID: lnwire.NewShortChanIDFromInt( + uint64(prand.Int63())), + incomingChanID: lnwire.NewShortChanIDFromInt( + uint64(prand.Int63())), + incomingHTLCID: uint64(start + i), + amount: lnwire.MilliSatoshi(prand.Int63()), + htlc: &lnwire.UpdateAddHTLC{ + ID: uint64(start + i), + }, + } + sentPackets[i] = pkt + + err := c.mailbox.AddPacket(pkt) + if err != nil { + c.t.Fatalf("unable to add packet: %v", err) + } + } + + return sentPackets +} + +func (c *mailboxContext) receivePkts(pkts []*htlcPacket) { + c.t.Helper() + + for i, expPkt := range pkts { + select { + case pkt := <-c.mailbox.PacketOutBox(): + if reflect.DeepEqual(expPkt, pkt) { + continue + } + + c.t.Fatalf("inkey mismatch #%d, want: %v vs "+ + "got: %v", i, expPkt.inKey(), pkt.inKey()) + + case <-time.After(50 * time.Millisecond): + c.t.Fatalf("did not receive fail for index %d", i) + } + } +} + +func (c *mailboxContext) checkFails(adds []*htlcPacket) { + c.t.Helper() + + for i, add := range adds { + select { + case fail := <-c.forwards: + if add.inKey() == fail.inKey() { + continue + } + c.t.Fatalf("inkey mismatch #%d, add: %v vs fail: %v", + i, add.inKey(), fail.inKey()) + + case <-time.After(50 * time.Millisecond): + c.t.Fatalf("did not receive fail for index %d", i) + } + } + + select { + case pkt := <-c.forwards: + c.t.Fatalf("unexpected forward: %v", pkt) + case <-time.After(50 * time.Millisecond): + } +} + +// TestMailBoxFailAdd asserts that FailAdd returns a response to the switch +// under various interleavings with other operations on the mailbox. +func TestMailBoxFailAdd(t *testing.T) { + ctx := newMailboxContext(t) + defer ctx.mailbox.Stop() + + failAdds := func(adds []*htlcPacket) { + for _, add := range adds { + ctx.mailbox.FailAdd(add) + } + } + + const numBatchPackets = 5 + + // Send 10 adds, and pull them from the mailbox. + adds := ctx.sendAdds(0, numBatchPackets) + ctx.receivePkts(adds) + + // Fail all of these adds, simulating an error adding the HTLCs to the + // commitment. We should see a failure message for each. + go failAdds(adds) + ctx.checkFails(adds) + + // As a sanity check, Fail all of them again and assert that no + // duplicate fails are sent. + go failAdds(adds) + ctx.checkFails(nil) +} + // TestMailOrchestrator asserts that the orchestrator properly buffers packets // for channels that haven't been made live, such that they are delivered // immediately after BindLiveShortChanID. It also tests that packets are delivered @@ -178,7 +316,7 @@ func TestMailOrchestrator(t *testing.T) { t.Parallel() // First, we'll create a new instance of our orchestrator. - mo := newMailOrchestrator() + mo := newMailOrchestrator(&mailOrchConfig{}) defer mo.Stop() // We'll be delivering 10 htlc packets via the orchestrator. @@ -203,7 +341,7 @@ func TestMailOrchestrator(t *testing.T) { } // Now, initialize a new mailbox for Alice's chanid. - mailbox := mo.GetOrCreateMailBox(chanID1) + mailbox := mo.GetOrCreateMailBox(chanID1, aliceChanID) // Verify that no messages are received, since Alice's mailbox has not // been made live. @@ -248,7 +386,7 @@ func TestMailOrchestrator(t *testing.T) { // For the second half of the test, create a new mailbox for Bob and // immediately make it live with an assigned short chan id. - mailbox = mo.GetOrCreateMailBox(chanID2) + mailbox = mo.GetOrCreateMailBox(chanID2, bobChanID) mo.BindLiveShortChanID(mailbox, chanID2, bobChanID) // Create the second half of our htlcs, and deliver them via the diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 78ff1005..ba57ed11 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -283,12 +283,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { return nil, err } - return &Switch{ + s := &Switch{ bestHeight: currentHeight, cfg: &cfg, circuits: circuitMap, linkIndex: make(map[lnwire.ChannelID]ChannelLink), - mailOrchestrator: newMailOrchestrator(), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[lnwire.ChannelID]ChannelLink), pendingLinkIndex: make(map[lnwire.ChannelID]ChannelLink), @@ -297,7 +296,14 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { chanCloseRequests: make(chan *ChanClose), resolutionMsgs: make(chan *resolutionMsg), quit: make(chan struct{}), - }, nil + } + + s.mailOrchestrator = newMailOrchestrator(&mailOrchConfig{ + fetchUpdate: s.cfg.FetchLastChannelUpdate, + forwardPackets: s.ForwardPackets, + }) + + return s, nil } // resolutionMsg is a struct that wraps an existing ResolutionMsg with a done @@ -2037,7 +2043,8 @@ func (s *Switch) AddLink(link ChannelLink) error { // Get and attach the mailbox for this link, which buffers packets in // case there packets that we tried to deliver while this link was // offline. - mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID) + shortChanID := link.ShortChanID() + mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID, shortChanID) link.AttachMailBox(mailbox) if err := link.Start(); err != nil { @@ -2045,7 +2052,6 @@ func (s *Switch) AddLink(link ChannelLink) error { return err } - shortChanID := link.ShortChanID() if shortChanID == hop.Source { log.Infof("Adding pending link chan_id=%v, short_chan_id=%v", chanID, shortChanID) @@ -2217,7 +2223,7 @@ func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID) error { // Finally, alert the mail orchestrator to the change of short channel // ID, and deliver any unclaimed packets to the link. - mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID) + mailbox := s.mailOrchestrator.GetOrCreateMailBox(chanID, shortChanID) s.mailOrchestrator.BindLiveShortChanID( mailbox, chanID, shortChanID, )