diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 6581b97f..4444ec31 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -390,6 +390,8 @@ func (s *mockServer) String() string { } type mockChannelLink struct { + htlcSwitch *Switch + shortChanID lnwire.ShortChannelID chanID lnwire.ChannelID @@ -399,12 +401,16 @@ type mockChannelLink struct { packets chan *htlcPacket eligible bool + + htlcID uint64 } -func newMockChannelLink(chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID, - peer Peer, eligible bool) *mockChannelLink { +func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, + shortChanID lnwire.ShortChannelID, peer Peer, eligible bool, +) *mockChannelLink { return &mockChannelLink{ + htlcSwitch: htlcSwitch, chanID: chanID, shortChanID: shortChanID, packets: make(chan *htlcPacket, 1), @@ -414,6 +420,19 @@ func newMockChannelLink(chanID lnwire.ChannelID, shortChanID lnwire.ShortChannel } func (f *mockChannelLink) HandleSwitchPacket(packet *htlcPacket) { + switch htlc := packet.htlc.(type) { + case *lnwire.UpdateAddHTLC: + f.htlcSwitch.addCircuit(&PaymentCircuit{ + PaymentHash: htlc.PaymentHash, + IncomingChanID: packet.incomingChanID, + IncomingHTLCID: packet.incomingHTLCID, + OutgoingChanID: f.shortChanID, + OutgoingHTLCID: f.htlcID, + ErrorEncrypter: packet.obfuscator, + }) + f.htlcID++ + } + f.packets <- packet } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 1aee7e92..7f6eb1f4 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -35,15 +35,15 @@ func TestSwitchForward(t *testing.T) { alicePeer := newMockServer(t, "alice") bobPeer := newMockServer(t, "bob") - aliceChannelLink := newMockChannelLink( - chanID1, aliceChanID, alicePeer, true, - ) - bobChannelLink := newMockChannelLink( - chanID2, bobChanID, bobPeer, true, - ) - s := New(Config{}) s.Start() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } @@ -71,15 +71,6 @@ func TestSwitchForward(t *testing.T) { t.Fatal(err) } - s.addCircuit(&PaymentCircuit{ - PaymentHash: rhash, - IncomingChanID: packet.incomingChanID, - IncomingHTLCID: packet.incomingHTLCID, - OutgoingChanID: packet.outgoingChanID, - OutgoingHTLCID: 0, - ErrorEncrypter: packet.obfuscator, - }) - select { case <-bobChannelLink.packets: break @@ -131,18 +122,19 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { alicePeer := newMockServer(t, "alice") bobPeer := newMockServer(t, "bob") + s := New(Config{}) + s.Start() + aliceChannelLink := newMockChannelLink( - chanID1, aliceChanID, alicePeer, true, + s, chanID1, aliceChanID, alicePeer, true, ) // We'll create a link for Bob, but mark the link as unable to forward // any new outgoing HTLC's. bobChannelLink := newMockChannelLink( - chanID2, bobChanID, bobPeer, false, + s, chanID2, bobChanID, bobPeer, false, ) - s := New(Config{}) - s.Start() if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } @@ -184,12 +176,13 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { // We'll create a single link for this test, marking it as being unable // to forward form the get go. alicePeer := newMockServer(t, "alice") - aliceChannelLink := newMockChannelLink( - chanID1, aliceChanID, alicePeer, false, - ) s := New(Config{}) s.Start() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, false, + ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } @@ -223,15 +216,15 @@ func TestSwitchCancel(t *testing.T) { alicePeer := newMockServer(t, "alice") bobPeer := newMockServer(t, "bob") - aliceChannelLink := newMockChannelLink( - chanID1, aliceChanID, alicePeer, true, - ) - bobChannelLink := newMockChannelLink( - chanID2, bobChanID, bobPeer, true, - ) - s := New(Config{}) s.Start() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } @@ -259,15 +252,6 @@ func TestSwitchCancel(t *testing.T) { t.Fatal(err) } - s.addCircuit(&PaymentCircuit{ - PaymentHash: rhash, - IncomingChanID: request.incomingChanID, - IncomingHTLCID: request.incomingHTLCID, - OutgoingChanID: request.outgoingChanID, - OutgoingHTLCID: 0, - ErrorEncrypter: request.obfuscator, - }) - select { case <-bobChannelLink.packets: break @@ -315,15 +299,15 @@ func TestSwitchAddSamePayment(t *testing.T) { alicePeer := newMockServer(t, "alice") bobPeer := newMockServer(t, "bob") - aliceChannelLink := newMockChannelLink( - chanID1, aliceChanID, alicePeer, true, - ) - bobChannelLink := newMockChannelLink( - chanID2, bobChanID, bobPeer, true, - ) - s := New(Config{}) s.Start() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add alice link: %v", err) } @@ -351,15 +335,6 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatal(err) } - s.addCircuit(&PaymentCircuit{ - PaymentHash: rhash, - IncomingChanID: request.incomingChanID, - IncomingHTLCID: request.incomingHTLCID, - OutgoingChanID: request.outgoingChanID, - OutgoingHTLCID: 0, - ErrorEncrypter: request.obfuscator, - }) - select { case <-bobChannelLink.packets: break @@ -387,15 +362,6 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatal(err) } - s.addCircuit(&PaymentCircuit{ - PaymentHash: rhash, - IncomingChanID: request.incomingChanID, - IncomingHTLCID: request.incomingHTLCID, - OutgoingChanID: request.outgoingChanID, - OutgoingHTLCID: 1, - ErrorEncrypter: request.obfuscator, - }) - if s.circuits.pending() != 2 { t.Fatal("wrong amount of circuits") } @@ -458,12 +424,13 @@ func TestSwitchSendPayment(t *testing.T) { t.Parallel() alicePeer := newMockServer(t, "alice") - aliceChannelLink := newMockChannelLink( - chanID1, aliceChanID, alicePeer, true, - ) s := New(Config{}) s.Start() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) if err := s.AddLink(aliceChannelLink); err != nil { t.Fatalf("unable to add link: %v", err) } @@ -515,7 +482,7 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("wrong amount of pending payments") } - if s.circuits.pending() != 0 { + if s.circuits.pending() != 2 { t.Fatal("wrong amount of circuits") } @@ -536,7 +503,6 @@ func TestSwitchSendPayment(t *testing.T) { isObfuscated: true, htlc: &lnwire.UpdateFailHTLC{ Reason: reason, - ID: 1, }, } @@ -553,6 +519,15 @@ func TestSwitchSendPayment(t *testing.T) { t.Fatal("err wasn't received") } + packet = &htlcPacket{ + outgoingChanID: aliceChannelLink.ShortChanID(), + outgoingHTLCID: 1, + isObfuscated: true, + htlc: &lnwire.UpdateFailHTLC{ + Reason: reason, + }, + } + // Send second failure response and check that user were able to // receive the error. if err := s.forward(packet); err != nil {