From 4f68d1beca0e43bfcdad5dc89ec527060a3e9119 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 4 Jan 2018 09:49:22 -0800 Subject: [PATCH] htlcswitch/switch_test: change forward() -> send() --- htlcswitch/switch_test.go | 1244 ++++++++++++++++++++++++++++++++++--- 1 file changed, 1141 insertions(+), 103 deletions(-) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index c6e6446e..b8d14aae 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1,8 +1,10 @@ package htlcswitch import ( - "bytes" + "crypto/rand" "crypto/sha256" + "io" + "io/ioutil" "testing" "time" @@ -11,39 +13,41 @@ import ( "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" - "github.com/roasbeef/btcd/chaincfg/chainhash" - "github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcutil" ) -var ( - hash1, _ = chainhash.NewHash(bytes.Repeat([]byte("a"), 32)) - hash2, _ = chainhash.NewHash(bytes.Repeat([]byte("b"), 32)) - - chanPoint1 = wire.NewOutPoint(hash1, 0) - chanPoint2 = wire.NewOutPoint(hash2, 0) - - chanID1 = lnwire.NewChanIDFromOutPoint(chanPoint1) - chanID2 = lnwire.NewChanIDFromOutPoint(chanPoint2) - - aliceChanID = lnwire.NewShortChanIDFromInt(1) - bobChanID = lnwire.NewShortChanIDFromInt(2) -) +func genPreimage() ([32]byte, error) { + var preimage [32]byte + if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { + return preimage, err + } + return preimage, nil +} // TestSwitchForward checks the ability of htlc switch to forward add/settle // requests. func TestSwitchForward(t *testing.T) { t.Parallel() - alicePeer := newMockServer(t, "alice") - bobPeer := newMockServer(t, "bob") + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } - s := New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }) - s.Start() + s, err := initSwitchWithDB(nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, @@ -60,13 +64,16 @@ func TestSwitchForward(t *testing.T) { // Create request which should be forwarded from Alice channel link to // bob channel link. - preimage := [sha256.Size]byte{1} + preimage, err := genPreimage() + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } rhash := fastsha256.Sum256(preimage[:]) packet := &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 0, outgoingChanID: bobChannelLink.ShortChanID(), - obfuscator: newMockObfuscator(), + obfuscator: NewMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -80,12 +87,14 @@ func TestSwitchForward(t *testing.T) { select { case <-bobChannelLink.packets: - break + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } - if s.circuits.pending() != 1 { + if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -107,17 +116,953 @@ func TestSwitchForward(t *testing.T) { } select { - case <-aliceChannelLink.packets: - break + case pkt := <-aliceChannelLink.packets: + if err := aliceChannelLink.deleteCircuit(pkt); err != nil { + t.Fatalf("unable to remove circuit: %v", err) + } case <-time.After(time.Second): t.Fatal("request was not propagated to channelPoint") } - if s.circuits.pending() != 0 { + if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") } } +func TestSwitchForwardFailAfterFullAdd(t *testing.T) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + tempPath, err := ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to temporary path: %v", err) + } + + cdb, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + s, err := initSwitchWithDB(cdb) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + + // Even though we intend to Stop s later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s.Stop() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + // Create request which should be forwarded from Alice channel link to + // bob channel link. + preimage := [sha256.Size]byte{1} + rhash := fastsha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + if s.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Handle the request and checks that bob channel link received it. + if err := s.forward(ogPacket); err != nil { + t.Fatal(err) + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Pull packet from bob's link, but do not perform a full add. + select { + case packet := <-bobChannelLink.packets: + // Complete the payment circuit and assign the outgoing htlc id + // before restarting. + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 1 { + t.Fatalf("wrong amount of circuits") + } + + // Now we will restart bob, leaving the forwarding decision for this + // htlc is in the half-added state. + if err := s.Stop(); err != nil { + t.Fatalf(err.Error()) + } + + if err := cdb.Close(); err != nil { + t.Fatalf(err.Error()) + } + + cdb2, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to reopen channeldb: %v", err) + } + + s2, err := initSwitchWithDB(cdb2) + if err != nil { + t.Fatalf("unable reinit switch: %v", err) + } + if err := s2.Start(); err != nil { + t.Fatalf("unable to restart switch: %v", err) + } + + // Even though we intend to Stop s2 later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s2.Stop() + + aliceChannelLink = newMockChannelLink( + s2, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink = newMockChannelLink( + s2, chanID2, bobChanID, bobPeer, true, + ) + if err := s2.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s2.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + if s2.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 1 { + t.Fatalf("wrong amount of circuits") + } + + // Craft a failure message from the remote peer. + fail := &htlcPacket{ + outgoingChanID: bobChannelLink.ShortChanID(), + outgoingHTLCID: 0, + amount: 1, + htlc: &lnwire.UpdateFailHTLC{}, + } + + // Send the fail packet from the remote peer through the switch. + if err := s2.forward(fail); err != nil { + t.Fatalf(err.Error()) + } + + // Pull packet from alice's link, as it should have gone through + // successfully. + select { + case pkt := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(pkt); err != nil { + t.Fatalf("unable to remove circuit: %v", err) + } + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + // Circuit map should be empty now. + if s2.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Send the fail packet from the remote peer through the switch. + if err := s2.forward(fail); err == nil { + t.Fatalf("expected failure when sending duplicate fail " + + "with no pending circuit") + } +} + +func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + tempPath, err := ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to temporary path: %v", err) + } + + cdb, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + s, err := initSwitchWithDB(cdb) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + + // Even though we intend to Stop s later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s.Stop() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + // Create request which should be forwarded from Alice channel link to + // bob channel link. + preimage := [sha256.Size]byte{1} + rhash := fastsha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + if s.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Handle the request and checks that bob channel link received it. + if err := s.forward(ogPacket); err != nil { + t.Fatal(err) + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Pull packet from bob's link, but do not perform a full add. + select { + case packet := <-bobChannelLink.packets: + // Complete the payment circuit and assign the outgoing htlc id + // before restarting. + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 1 { + t.Fatalf("wrong amount of circuits") + } + + // Now we will restart bob, leaving the forwarding decision for this + // htlc is in the half-added state. + if err := s.Stop(); err != nil { + t.Fatalf(err.Error()) + } + + if err := cdb.Close(); err != nil { + t.Fatalf(err.Error()) + } + + cdb2, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to reopen channeldb: %v", err) + } + + s2, err := initSwitchWithDB(cdb2) + if err != nil { + t.Fatalf("unable reinit switch: %v", err) + } + if err := s2.Start(); err != nil { + t.Fatalf("unable to restart switch: %v", err) + } + + // Even though we intend to Stop s2 later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s2.Stop() + + aliceChannelLink = newMockChannelLink( + s2, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink = newMockChannelLink( + s2, chanID2, bobChanID, bobPeer, true, + ) + if err := s2.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s2.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + if s2.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 1 { + t.Fatalf("wrong amount of circuits") + } + + // Craft a settle message from the remote peer. + settle := &htlcPacket{ + outgoingChanID: bobChannelLink.ShortChanID(), + outgoingHTLCID: 0, + amount: 1, + htlc: &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: preimage, + }, + } + + // Send the settle packet from the remote peer through the switch. + if err := s2.forward(settle); err != nil { + t.Fatalf(err.Error()) + } + + // Pull packet from alice's link, as it should have gone through + // successfully. + select { + case packet := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete circuit with in key=%s: %v", + packet.inKey(), err) + } + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + // Circuit map should be empty now. + if s2.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Send the settle packet again, which should fail. + if err := s2.forward(settle); err == nil { + t.Fatalf("expected failure when sending duplicate settle " + + "with no pending circuit") + } +} + +func TestSwitchForwardDropAfterFullAdd(t *testing.T) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + tempPath, err := ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to temporary path: %v", err) + } + + cdb, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + s, err := initSwitchWithDB(cdb) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + + // Even though we intend to Stop s later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s.Stop() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + // Create request which should be forwarded from Alice channel link to + // bob channel link. + preimage := [sha256.Size]byte{1} + rhash := fastsha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + if s.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Handle the request and checks that bob channel link received it. + if err := s.forward(ogPacket); err != nil { + t.Fatal(err) + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of half circuits") + } + + // Pull packet from bob's link, but do not perform a full add. + select { + case packet := <-bobChannelLink.packets: + // Complete the payment circuit and assign the outgoing htlc id + // before restarting. + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + // Now we will restart bob, leaving the forwarding decision for this + // htlc is in the half-added state. + if err := s.Stop(); err != nil { + t.Fatalf(err.Error()) + } + + if err := cdb.Close(); err != nil { + t.Fatalf(err.Error()) + } + + cdb2, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to reopen channeldb: %v", err) + } + + s2, err := initSwitchWithDB(cdb2) + if err != nil { + t.Fatalf("unable reinit switch: %v", err) + } + if err := s2.Start(); err != nil { + t.Fatalf("unable to restart switch: %v", err) + } + + // Even though we intend to Stop s2 later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s2.Stop() + + aliceChannelLink = newMockChannelLink( + s2, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink = newMockChannelLink( + s2, chanID2, bobChanID, bobPeer, true, + ) + if err := s2.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s2.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + if s2.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 1 { + t.Fatalf("wrong amount of half circuits") + } + + // Resend the failed htlc, it should be returned to alice since the + // switch will detect that it has been half added previously. + err = s2.forward(ogPacket) + if err != ErrDuplicateAdd { + t.Fatal("unexpected error when reforwarding a "+ + "failed packet", err) + } + + // After detecting an incomplete forward, the fail packet should have + // been returned to the sender. + select { + case <-aliceChannelLink.packets: + t.Fatal("request should not have returned to source") + case <-bobChannelLink.packets: + t.Fatal("request should not have forwarded to destination") + case <-time.After(time.Second): + } +} + +func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + tempPath, err := ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to temporary path: %v", err) + } + + cdb, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + s, err := initSwitchWithDB(cdb) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + + // Even though we intend to Stop s later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s.Stop() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + // Create request which should be forwarded from Alice channel link to + // bob channel link. + preimage := [sha256.Size]byte{1} + rhash := fastsha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + if s.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Handle the request and checks that bob channel link received it. + if err := s.forward(ogPacket); err != nil { + t.Fatal(err) + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of half circuits") + } + + // Pull packet from bob's link, but do not perform a full add. + select { + case <-bobChannelLink.packets: + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + // Now we will restart bob, leaving the forwarding decision for this + // htlc is in the half-added state. + if err := s.Stop(); err != nil { + t.Fatalf(err.Error()) + } + + if err := cdb.Close(); err != nil { + t.Fatalf(err.Error()) + } + + cdb2, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to reopen channeldb: %v", err) + } + + s2, err := initSwitchWithDB(cdb2) + if err != nil { + t.Fatalf("unable reinit switch: %v", err) + } + if err := s2.Start(); err != nil { + t.Fatalf("unable to restart switch: %v", err) + } + + // Even though we intend to Stop s2 later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s2.Stop() + + aliceChannelLink = newMockChannelLink( + s2, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink = newMockChannelLink( + s2, chanID2, bobChanID, bobPeer, true, + ) + if err := s2.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s2.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + if s2.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of half circuits") + } + + // Resend the failed htlc, it should be returned to alice since the + // switch will detect that it has been half added previously. + err = s2.forward(ogPacket) + if err != ErrIncompleteForward { + t.Fatal("unexpected error when reforwarding a "+ + "failed packet", err) + } + + // After detecting an incomplete forward, the fail packet should have + // been returned to the sender. + select { + case <-aliceChannelLink.packets: + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } +} + +// TestSwitchForwardCircuitPersistence checks the ability of htlc switch to +// maintain the proper entries in the circuit map in the face of restarts. +func TestSwitchForwardCircuitPersistence(t *testing.T) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + tempPath, err := ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to temporary path: %v", err) + } + + cdb, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + s, err := initSwitchWithDB(cdb) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + + // Even though we intend to Stop s later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s.Stop() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + // Create request which should be forwarded from Alice channel link to + // bob channel link. + preimage := [sha256.Size]byte{1} + rhash := fastsha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + if s.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Handle the request and checks that bob channel link received it. + if err := s.forward(ogPacket); err != nil { + t.Fatal(err) + } + + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } + + // Retrieve packet from outgoing link and cache until after restart. + var packet *htlcPacket + select { + case packet = <-bobChannelLink.packets: + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + if err := s.Stop(); err != nil { + t.Fatalf(err.Error()) + } + + if err := cdb.Close(); err != nil { + t.Fatalf(err.Error()) + } + + cdb2, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to reopen channeldb: %v", err) + } + + s2, err := initSwitchWithDB(cdb2) + if err != nil { + t.Fatalf("unable reinit switch: %v", err) + } + if err := s2.Start(); err != nil { + t.Fatalf("unable to restart switch: %v", err) + } + + // Even though we intend to Stop s2 later in the test, it is safe to + // defer this Stop since its execution it is protected by an atomic + // guard, guaranteeing it executes at most once. + defer s2.Stop() + + aliceChannelLink = newMockChannelLink( + s2, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink = newMockChannelLink( + s2, chanID2, bobChanID, bobPeer, true, + ) + if err := s2.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s2.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + if s2.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of half circuits") + } + + // Now that the switch has restarted, complete the payment circuit. + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + + if s2.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s2.circuits.NumOpen() != 1 { + t.Fatal("wrong amount of circuits") + } + + // Create settle request pretending that bob link handled the add htlc + // request and sent the htlc settle request back. This request should + // be forwarder back to Alice link. + ogPacket = &htlcPacket{ + outgoingChanID: bobChannelLink.ShortChanID(), + outgoingHTLCID: 0, + amount: 1, + htlc: &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: preimage, + }, + } + + // Handle the request and checks that payment circuit works properly. + if err := s2.forward(ogPacket); err != nil { + t.Fatal(err) + } + + select { + case packet = <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete circuit with in key=%s: %v", + packet.inKey(), err) + } + case <-time.After(time.Second): + t.Fatal("request was not propagated to channelPoint") + } + + if s2.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits, want 1, got %d", + s2.circuits.NumPending()) + } + if s2.circuits.NumOpen() != 0 { + t.Fatal("wrong amount of circuits") + } + + if err := s2.Stop(); err != nil { + t.Fatal(err) + } + + if err := cdb2.Close(); err != nil { + t.Fatalf(err.Error()) + } + + cdb3, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to reopen channeldb: %v", err) + } + + s3, err := initSwitchWithDB(cdb3) + if err != nil { + t.Fatalf("unable reinit switch: %v", err) + } + if err := s3.Start(); err != nil { + t.Fatalf("unable to restart switch: %v", err) + } + defer s3.Stop() + + aliceChannelLink = newMockChannelLink( + s3, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink = newMockChannelLink( + s3, chanID2, bobChanID, bobPeer, true, + ) + if err := s3.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s3.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + if s3.circuits.NumPending() != 0 { + t.Fatalf("wrong amount of half circuits") + } + if s3.circuits.NumOpen() != 0 { + t.Fatalf("wrong amount of circuits") + } +} + // TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes // along, then we won't attempt to froward it down al ink that isn't yet able // to forward any HTLC's. @@ -126,15 +1071,25 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { var packet *htlcPacket - alicePeer := newMockServer(t, "alice") - bobPeer := newMockServer(t, "bob") + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } - s := New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }) - s.Start() + s, err := initSwitchWithDB(nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, @@ -165,16 +1120,16 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { PaymentHash: rhash, Amount: 1, }, - obfuscator: newMockObfuscator(), + obfuscator: NewMockObfuscator(), } // The request to forward should fail as - err := s.forward(packet) + err = s.forward(packet) if err == nil { t.Fatalf("forwarding should have failed due to inactive link") } - if s.circuits.pending() != 0 { + if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") } } @@ -186,14 +1141,21 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { // We'll create a single link for this test, marking it as being unable // to forward form the get go. - alicePeer := newMockServer(t, "alice") + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } - s := New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }) - s.Start() + s, err := initSwitchWithDB(nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + chanID1, _, aliceChanID, _ := genIDs() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, false, @@ -202,7 +1164,10 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { t.Fatalf("unable to add alice link: %v", err) } - preimage := [sha256.Size]byte{1} + preimage, err := genPreimage() + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } rhash := fastsha256.Sum256(preimage[:]) addMsg := &lnwire.UpdateAddHTLC{ PaymentHash: rhash, @@ -213,12 +1178,12 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { // outgoing link. This should fail as Alice isn't yet able to forward // any active HTLC's. alicePub := aliceChannelLink.Peer().PubKey() - _, err := s.SendHTLC(alicePub, addMsg, nil) + _, err = s.SendHTLC(alicePub, addMsg, nil) if err == nil { t.Fatalf("local forward should fail due to inactive link") } - if s.circuits.pending() != 0 { + if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") } } @@ -228,15 +1193,25 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { func TestSwitchCancel(t *testing.T) { t.Parallel() - alicePeer := newMockServer(t, "alice") - bobPeer := newMockServer(t, "bob") + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } - s := New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }) - s.Start() + s, err := initSwitchWithDB(nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, @@ -253,13 +1228,16 @@ func TestSwitchCancel(t *testing.T) { // Create request which should be forwarder from alice channel link // to bob channel link. - preimage := [sha256.Size]byte{1} + preimage, err := genPreimage() + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } rhash := fastsha256.Sum256(preimage[:]) request := &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 0, outgoingChanID: bobChannelLink.ShortChanID(), - obfuscator: newMockObfuscator(), + obfuscator: NewMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -272,13 +1250,19 @@ func TestSwitchCancel(t *testing.T) { } select { - case <-bobChannelLink.packets: - break + case packet := <-bobChannelLink.packets: + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } - if s.circuits.pending() != 1 { + if s.circuits.NumPending() != 1 { + t.Fatalf("wrong amount of half circuits") + } + if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -298,13 +1282,19 @@ func TestSwitchCancel(t *testing.T) { } select { - case <-aliceChannelLink.packets: - break + case pkt := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(pkt); err != nil { + t.Fatalf("unable to remove circuit: %v", err) + } + case <-time.After(time.Second): t.Fatal("request was not propagated to channelPoint") } - if s.circuits.pending() != 0 { + if s.circuits.NumPending() != 0 { + t.Fatal("wrong amount of circuits") + } + if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") } } @@ -314,15 +1304,25 @@ func TestSwitchCancel(t *testing.T) { func TestSwitchAddSamePayment(t *testing.T) { t.Parallel() - alicePeer := newMockServer(t, "alice") - bobPeer := newMockServer(t, "bob") + chanID1, chanID2, aliceChanID, bobChanID := genIDs() - s := New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }) - s.Start() + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer(t, "bob", nil) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + s, err := initSwitchWithDB(nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, @@ -339,13 +1339,16 @@ func TestSwitchAddSamePayment(t *testing.T) { // Create request which should be forwarder from alice channel link // to bob channel link. - preimage := [sha256.Size]byte{1} + preimage, err := genPreimage() + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } rhash := fastsha256.Sum256(preimage[:]) request := &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 0, outgoingChanID: bobChannelLink.ShortChanID(), - obfuscator: newMockObfuscator(), + obfuscator: NewMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -358,13 +1361,16 @@ func TestSwitchAddSamePayment(t *testing.T) { } select { - case <-bobChannelLink.packets: - break + case packet := <-bobChannelLink.packets: + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } - if s.circuits.pending() != 1 { + if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -372,7 +1378,7 @@ func TestSwitchAddSamePayment(t *testing.T) { incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 1, outgoingChanID: bobChannelLink.ShortChanID(), - obfuscator: newMockObfuscator(), + obfuscator: NewMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -384,7 +1390,17 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatal(err) } - if s.circuits.pending() != 2 { + select { + case packet := <-bobChannelLink.packets: + if err := bobChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + + case <-time.After(time.Second): + t.Fatal("request was not propagated to destination") + } + + if s.circuits.NumOpen() != 2 { t.Fatal("wrong amount of circuits") } @@ -404,13 +1420,16 @@ func TestSwitchAddSamePayment(t *testing.T) { } select { - case <-aliceChannelLink.packets: - break + case pkt := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(pkt); err != nil { + t.Fatalf("unable to remove circuit: %v", err) + } + case <-time.After(time.Second): t.Fatal("request was not propagated to channelPoint") } - if s.circuits.pending() != 1 { + if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -427,13 +1446,16 @@ func TestSwitchAddSamePayment(t *testing.T) { } select { - case <-aliceChannelLink.packets: - break + case pkt := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(pkt); err != nil { + t.Fatalf("unable to remove circuit: %v", err) + } + case <-time.After(time.Second): t.Fatal("request was not propagated to channelPoint") } - if s.circuits.pending() != 0 { + if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") } } @@ -443,14 +1465,21 @@ func TestSwitchAddSamePayment(t *testing.T) { func TestSwitchSendPayment(t *testing.T) { t.Parallel() - alicePeer := newMockServer(t, "alice") + alicePeer, err := newMockServer(t, "alice", nil) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } - s := New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }) - s.Start() + s, err := initSwitchWithDB(nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + chanID1, _, aliceChanID, _ := genIDs() aliceChannelLink := newMockChannelLink( s, chanID1, aliceChanID, alicePeer, true, @@ -461,7 +1490,10 @@ func TestSwitchSendPayment(t *testing.T) { // Create request which should be forwarder from alice channel link // to bob channel link. - preimage := [sha256.Size]byte{1} + preimage, err := genPreimage() + if err != nil { + t.Fatalf("unable to generate preimage: %v", err) + } rhash := fastsha256.Sum256(preimage[:]) update := &lnwire.UpdateAddHTLC{ PaymentHash: rhash, @@ -485,8 +1517,11 @@ func TestSwitchSendPayment(t *testing.T) { }() select { - case <-aliceChannelLink.packets: - break + case packet := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + case err := <-errChan: t.Fatalf("unable to send payment: %v", err) case <-time.After(time.Second): @@ -494,8 +1529,11 @@ func TestSwitchSendPayment(t *testing.T) { } select { - case <-aliceChannelLink.packets: - break + case packet := <-aliceChannelLink.packets: + if err := aliceChannelLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + case err := <-errChan: t.Fatalf("unable to send payment: %v", err) case <-time.After(time.Second): @@ -506,14 +1544,14 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("wrong amount of pending payments") } - if s.circuits.pending() != 2 { + if s.circuits.NumOpen() != 2 { t.Fatal("wrong amount of circuits") } // Create fail request pretending that bob channel link handled // the add htlc request with error and sent the htlc fail request // back. This request should be forwarded back to alice channel link. - obfuscator := newMockObfuscator() + obfuscator := NewMockObfuscator() failure := lnwire.FailIncorrectPaymentAmount{} reason, err := obfuscator.EncryptFirstHop(failure) if err != nil {