diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index 1c1ad5c4..4eff7026 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -4,6 +4,7 @@ import ( "bytes" "container/list" "errors" + "fmt" "sync" "time" @@ -108,8 +109,13 @@ type memoryMailBox struct { htlcPkts *list.List pktIndex map[CircuitKey]*list.Element pktHead *list.Element - pktMtx sync.Mutex - pktCond *sync.Cond + + addPkts *list.List + addIndex map[CircuitKey]*list.Element + addHead *list.Element + + pktMtx sync.Mutex + pktCond *sync.Cond pktOutbox chan *htlcPacket pktReset chan chan struct{} @@ -125,11 +131,13 @@ func newMemoryMailBox(cfg *mailBoxConfig) *memoryMailBox { cfg: cfg, wireMessages: list.New(), htlcPkts: list.New(), + addPkts: list.New(), messageOutbox: make(chan lnwire.Message), pktOutbox: make(chan *htlcPacket), msgReset: make(chan chan struct{}, 1), pktReset: make(chan chan struct{}, 1), pktIndex: make(map[CircuitKey]*list.Element), + addIndex: make(map[CircuitKey]*list.Element), wireShutdown: make(chan struct{}), pktShutdown: make(chan struct{}), quit: make(chan struct{}), @@ -222,24 +230,39 @@ func (m *memoryMailBox) signalUntilReset(cType courierType, // NOTE: It is safe to call this method multiple times for the same circuit key. func (m *memoryMailBox) AckPacket(inKey CircuitKey) bool { m.pktCond.L.Lock() - entry, ok := m.pktIndex[inKey] - if !ok { - m.pktCond.L.Unlock() - return false + defer m.pktCond.L.Unlock() + + if entry, ok := m.pktIndex[inKey]; ok { + // Check whether we are removing the head of the queue. If so, + // we must advance the head to the next packet before removing. + // It's possible that the courier has already advanced the + // pktHead, so this check prevents the pktHead from getting + // desynchronized. + if entry == m.pktHead { + m.pktHead = entry.Next() + } + m.htlcPkts.Remove(entry) + delete(m.pktIndex, inKey) + + return true } - // Check whether we are removing the head of the queue. If so, we must - // advance the head to the next packet before removing. It's possible - // that the courier has already adanced the pktHead, so this check - // prevents the pktHead from getting desynchronized. - if entry == m.pktHead { - m.pktHead = entry.Next() - } - m.htlcPkts.Remove(entry) - delete(m.pktIndex, inKey) - m.pktCond.L.Unlock() + if entry, ok := m.addIndex[inKey]; ok { + // Check whether we are removing the head of the queue. If so, + // we must advance the head to the next add before removing. + // It's possible that the courier has already advanced the + // addHead, so this check prevents the addHead from getting + // desynchronized. + if entry == m.addHead { + m.addHead = entry.Next() + } + m.addPkts.Remove(entry) + delete(m.addIndex, inKey) - return true + return true + } + + return false } // HasPacket queries the packets for a circuit key, this is used to drop packets @@ -328,7 +351,7 @@ func (m *memoryMailBox) mailCourier(cType courierType) { case pktCourier: m.pktCond.L.Lock() - for m.pktHead == nil { + for m.pktHead == nil && m.addHead == nil { m.pktCond.Wait() select { @@ -338,6 +361,7 @@ func (m *memoryMailBox) mailCourier(cType courierType) { // reconnect. case pktDone := <-m.pktReset: m.pktHead = m.htlcPkts.Front() + m.addHead = m.addPkts.Front() close(pktDone) case <-m.quit: @@ -351,6 +375,8 @@ func (m *memoryMailBox) mailCourier(cType courierType) { var ( nextPkt *htlcPacket nextPktEl *list.Element + nextAdd *htlcPacket + nextAddEl *list.Element nextMsg lnwire.Message ) switch cType { @@ -366,8 +392,15 @@ func (m *memoryMailBox) mailCourier(cType courierType) { // doesn't make it into a commitment, then it'll be // re-delivered once the link comes back online. case pktCourier: - nextPkt = m.pktHead.Value.(*htlcPacket) - nextPktEl = m.pktHead + // Peek at the next item to deliver, prioritizing + // Settle/Fail packets over Adds. + if m.pktHead != nil { + nextPkt = m.pktHead.Value.(*htlcPacket) + nextPktEl = m.pktHead + } else { + nextAdd = m.addHead.Value.(*htlcPacket) + nextAddEl = m.addHead + } } // Now that we're done with the condition, we can unlock it to @@ -397,22 +430,56 @@ func (m *memoryMailBox) mailCourier(cType courierType) { } case pktCourier: + var ( + pktOutbox chan *htlcPacket + addOutbox chan *htlcPacket + ) + + // Prioritize delivery of Settle/Fail packets over Adds. + // This ensures that we actively clear the commitment of + // existing HTLCs before trying to add new ones. This + // can help to improve forwarding performance since the + // time to sign a commitment is linear in the number of + // HTLCs manifested on the commitments. + // + // NOTE: Both types are eventually delivered over the + // same channel, but we can control which is delivered + // by exclusively making one nil and the other non-nil. + // We know from our loop condition that at least one + // nextPkt and nextAdd are non-nil. + if nextPkt != nil { + pktOutbox = m.pktOutbox + } else { + addOutbox = m.pktOutbox + } + select { - case m.pktOutbox <- nextPkt: + case pktOutbox <- nextPkt: m.pktCond.L.Lock() - // Only advance the pktHead if this packet - // is still at the head of the queue. + // Only advance the pktHead if this Settle or + // Fail is still at the head of the queue. if m.pktHead != nil && m.pktHead == nextPktEl { m.pktHead = m.pktHead.Next() } m.pktCond.L.Unlock() + case addOutbox <- nextAdd: + m.pktCond.L.Lock() + // Only advance the addHead if this Add is still + // at the head of the queue. + if m.addHead != nil && m.addHead == nextAddEl { + m.addHead = m.addHead.Next() + } + m.pktCond.L.Unlock() + case pktDone := <-m.pktReset: m.pktCond.L.Lock() m.pktHead = m.htlcPkts.Front() + m.addHead = m.addPkts.Front() m.pktCond.L.Unlock() close(pktDone) + case <-m.quit: return } @@ -444,18 +511,38 @@ func (m *memoryMailBox) AddMessage(msg lnwire.Message) error { // NOTE: This method is safe for concrete use and part of the MailBox // interface. func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error { - // First, we'll lock the condition, and add the packet to the end of - // the htlc packet inbox. m.pktCond.L.Lock() - if _, ok := m.pktIndex[pkt.inKey()]; ok { - m.pktCond.L.Unlock() - return nil - } + switch htlc := pkt.htlc.(type) { - entry := m.htlcPkts.PushBack(pkt) - m.pktIndex[pkt.inKey()] = entry - if m.pktHead == nil { - m.pktHead = entry + // Split off Settle/Fail packets into the htlcPkts queue. + case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC: + if _, ok := m.pktIndex[pkt.inKey()]; ok { + m.pktCond.L.Unlock() + return nil + } + + entry := m.htlcPkts.PushBack(pkt) + m.pktIndex[pkt.inKey()] = entry + if m.pktHead == nil { + m.pktHead = entry + } + + // Split off Add packets into the addPkts queue. + case *lnwire.UpdateAddHTLC: + if _, ok := m.addIndex[pkt.inKey()]; ok { + m.pktCond.L.Unlock() + return nil + } + + entry := m.addPkts.PushBack(pkt) + m.addIndex[pkt.inKey()] = entry + if m.addHead == nil { + m.addHead = entry + } + + default: + m.pktCond.L.Unlock() + return fmt.Errorf("unknown htlc type: %T", htlc) } m.pktCond.L.Unlock() diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index 040f2d34..6a7cf026 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -38,6 +38,9 @@ func TestMailBoxCouriers(t *testing.T) { outgoingChanID: lnwire.NewShortChanIDFromInt(uint64(prand.Int63())), incomingChanID: lnwire.NewShortChanIDFromInt(uint64(prand.Int63())), amount: lnwire.MilliSatoshi(prand.Int63()), + htlc: &lnwire.UpdateAddHTLC{ + ID: uint64(i), + }, } sentPackets[i] = pkt @@ -315,6 +318,106 @@ func TestMailBoxFailAdd(t *testing.T) { // duplicate fails are sent. go failAdds(adds) ctx.checkFails(nil) + +} + +// TestMailBoxPacketPrioritization asserts that the mailbox will prioritize +// delivering Settle and Fail packets over Adds if both are available for +// delivery at the same time. +func TestMailBoxPacketPrioritization(t *testing.T) { + t.Parallel() + + // First, we'll create new instance of the current default mailbox + // type. + mailBox := newMemoryMailBox(&mailBoxConfig{ + clock: clock.NewDefaultClock(), + expiry: time.Minute, + }) + mailBox.Start() + defer mailBox.Stop() + + const numPackets = 5 + + _, _, aliceChanID, bobChanID := genIDs() + + // Next we'll send the following sequence of packets: + // - Settle1 + // - Add1 + // - Add2 + // - Fail + // - Settle2 + sentPackets := make([]*htlcPacket, numPackets) + for i := 0; i < numPackets; i++ { + pkt := &htlcPacket{ + outgoingChanID: aliceChanID, + outgoingHTLCID: uint64(i), + incomingChanID: bobChanID, + incomingHTLCID: uint64(i), + amount: lnwire.MilliSatoshi(prand.Int63()), + } + + switch i { + case 0, 4: + // First and last packets are a Settle. A non-Add is + // sent first to make the test deterministic w/o needing + // to sleep. + pkt.htlc = &lnwire.UpdateFulfillHTLC{ID: uint64(i)} + case 1, 2: + // Next two packets are Adds. + pkt.htlc = &lnwire.UpdateAddHTLC{ID: uint64(i)} + case 3: + // Last packet is a Fail. + pkt.htlc = &lnwire.UpdateFailHTLC{ID: uint64(i)} + } + + sentPackets[i] = pkt + + err := mailBox.AddPacket(pkt) + if err != nil { + t.Fatalf("failed to add packet: %v", err) + } + } + + // When dequeueing the packets, we expect the following sequence: + // - Settle1 + // - Fail + // - Settle2 + // - Add1 + // - Add2 + // + // We expect to see Fail and Settle2 to be delivered before either Add1 + // or Add2 due to the prioritization between the split queue. + for i := 0; i < numPackets; i++ { + select { + case pkt := <-mailBox.PacketOutBox(): + var expPkt *htlcPacket + switch i { + case 0: + // First packet should be Settle1. + expPkt = sentPackets[0] + case 1: + // Second packet should be Fail. + expPkt = sentPackets[3] + case 2: + // Third packet should be Settle2. + expPkt = sentPackets[4] + case 3: + // Fourth packet should be Add1. + expPkt = sentPackets[1] + case 4: + // Last packet should be Add2. + expPkt = sentPackets[2] + } + + if !reflect.DeepEqual(expPkt, pkt) { + t.Fatalf("recvd packet mismatch %d, want: %v, got: %v", + i, spew.Sdump(expPkt), spew.Sdump(pkt)) + } + + case <-time.After(50 * time.Millisecond): + t.Fatalf("didn't receive packet %d before timeout", i) + } + } } // TestMailOrchestrator asserts that the orchestrator properly buffers packets @@ -346,6 +449,9 @@ func TestMailOrchestrator(t *testing.T) { incomingChanID: bobChanID, incomingHTLCID: uint64(i), amount: lnwire.MilliSatoshi(prand.Int63()), + htlc: &lnwire.UpdateAddHTLC{ + ID: uint64(i), + }, } sentPackets[i] = pkt @@ -411,6 +517,9 @@ func TestMailOrchestrator(t *testing.T) { incomingChanID: bobChanID, incomingHTLCID: uint64(halfPackets + i), amount: lnwire.MilliSatoshi(prand.Int63()), + htlc: &lnwire.UpdateAddHTLC{ + ID: uint64(halfPackets + i), + }, } sentPackets[i] = pkt