diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index 2d04ea69..e8356c97 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -147,3 +147,121 @@ func TestMailBoxCouriers(t *testing.T) { spew.Sdump(sentPackets), spew.Sdump(recvdPackets)) } } + +// TestMailOrchestrator asserts that the orchestrator properly buffers packets +// for channels that haven't been made live, such that they are delivered +// immediately after BindLiveShortChanID. It also tests that packets are delivered +// readily to mailboxes for channels that are already in the live state. +func TestMailOrchestrator(t *testing.T) { + t.Parallel() + + // First, we'll create a new instance of our orchestrator. + mo := newMailOrchestrator() + defer mo.Stop() + + // We'll be delivering 10 htlc packets via the orchestrator. + const numPackets = 10 + const halfPackets = numPackets / 2 + + // Before any mailbox is created or made live, we will deliver half of + // the htlcs via the orchestrator. + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + sentPackets := make([]*htlcPacket, halfPackets) + for i := 0; i < halfPackets; i++ { + pkt := &htlcPacket{ + outgoingChanID: aliceChanID, + outgoingHTLCID: uint64(i), + incomingChanID: bobChanID, + incomingHTLCID: uint64(i), + amount: lnwire.MilliSatoshi(prand.Int63()), + } + sentPackets[i] = pkt + + mo.Deliver(pkt.outgoingChanID, pkt) + } + + // Now, initialize a new mailbox for Alice's chanid. + mailbox := mo.GetOrCreateMailBox(chanID1) + + // Verify that no messages are received, since Alice's mailbox has not + // been made live. + for i := 0; i < halfPackets; i++ { + timeout := time.After(50 * time.Millisecond) + select { + case <-mailbox.MessageOutBox(): + t.Fatalf("should not receive wire msg after reset") + case <-timeout: + } + } + + // Assign a short chan id to the existing mailbox, make it available for + // capturing incoming HTLCs. The HTLCs added above should be delivered + // immediately. + mo.BindLiveShortChanID(mailbox, chanID1, aliceChanID) + + // Verify that all of the packets are queued and delivered to Alice's + // mailbox. + recvdPackets := make([]*htlcPacket, 0, len(sentPackets)) + for i := 0; i < halfPackets; i++ { + timeout := time.After(5 * time.Second) + select { + case <-timeout: + t.Fatalf("didn't recv pkt %d after timeout", i) + case pkt := <-mailbox.PacketOutBox(): + recvdPackets = append(recvdPackets, pkt) + } + } + + // We should have received half of the total number of packets. + if len(recvdPackets) != halfPackets { + t.Fatalf("expected %v packets instead got %v", + halfPackets, len(recvdPackets)) + } + + // Check that the received packets are equal to the sent packets. + if !reflect.DeepEqual(recvdPackets, sentPackets) { + t.Fatalf("recvd packets mismatched: expected %v, got %v", + spew.Sdump(sentPackets), spew.Sdump(recvdPackets)) + } + + // For the second half of the test, create a new mailbox for Bob and + // immediately make it live with an assigned short chan id. + mailbox = mo.GetOrCreateMailBox(chanID2) + mo.BindLiveShortChanID(mailbox, chanID2, bobChanID) + + // Create the second half of our htlcs, and deliver them via the + // orchestrator. We should be able to receive each of these in order. + recvdPackets = make([]*htlcPacket, 0, len(sentPackets)) + for i := 0; i < halfPackets; i++ { + pkt := &htlcPacket{ + outgoingChanID: aliceChanID, + outgoingHTLCID: uint64(halfPackets + i), + incomingChanID: bobChanID, + incomingHTLCID: uint64(halfPackets + i), + amount: lnwire.MilliSatoshi(prand.Int63()), + } + sentPackets[i] = pkt + + mo.Deliver(pkt.incomingChanID, pkt) + + timeout := time.After(50 * time.Millisecond) + select { + case <-timeout: + t.Fatalf("didn't recv pkt %d after timeout", halfPackets+i) + case pkt := <-mailbox.PacketOutBox(): + recvdPackets = append(recvdPackets, pkt) + } + } + + // Again, we should have received half of the total number of packets. + if len(recvdPackets) != halfPackets { + t.Fatalf("expected %v packets instead got %v", + halfPackets, len(recvdPackets)) + } + + // Check that the received packets are equal to the sent packets. + if !reflect.DeepEqual(recvdPackets, sentPackets) { + t.Fatalf("recvd packets mismatched: expected %v, got %v", + spew.Sdump(sentPackets), spew.Sdump(recvdPackets)) + } +}