From babb0a36b4f479009fab6a89d003154e163fda7a Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 13 Apr 2020 14:23:36 +0200 Subject: [PATCH] switch/test: use external interface for testing Previously the forward(...) method was used in forwarding tests, while that code path isn't used for forwards in reality. --- htlcswitch/switch_test.go | 88 ++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index de905961..60604c42 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -213,7 +213,7 @@ func TestSwitchSendPending(t *testing.T) { // Send the ADD packet, this should not be forwarded out to the link // since there are no eligible links. - err = s.forward(packet) + err = forwardPackets(t, s, packet) linkErr, ok := err.(*LinkError) if !ok { t.Fatalf("expected link error, got: %T", err) @@ -249,7 +249,7 @@ func TestSwitchSendPending(t *testing.T) { packet.incomingHTLCID++ // Handle the request and checks that bob channel link received it. - if err := s.forward(packet); err != nil { + if err := forwardPackets(t, s, packet); err != nil { t.Fatalf("unexpected forward failure: %v", err) } @@ -322,7 +322,7 @@ func TestSwitchForward(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(packet); err != nil { + if err := forwardPackets(t, s, packet); err != nil { t.Fatal(err) } @@ -356,7 +356,7 @@ func TestSwitchForward(t *testing.T) { } // Handle the request and checks that payment circuit works properly. - if err := s.forward(packet); err != nil { + if err := forwardPackets(t, s, packet); err != nil { t.Fatal(err) } @@ -451,7 +451,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(ogPacket); err != nil { + if err := forwardPackets(t, s, ogPacket); err != nil { t.Fatal(err) } @@ -539,7 +539,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { } // Send the fail packet from the remote peer through the switch. - if err := s2.forward(fail); err != nil { + if err := <-s2.ForwardPackets(nil, fail); err != nil { t.Fatalf(err.Error()) } @@ -563,7 +563,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { } // Send the fail packet from the remote peer through the switch. - if err := s2.forward(fail); err == nil { + if err := <-s2.ForwardPackets(nil, fail); err == nil { t.Fatalf("expected failure when sending duplicate fail " + "with no pending circuit") } @@ -646,7 +646,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(ogPacket); err != nil { + if err := forwardPackets(t, s, ogPacket); err != nil { t.Fatal(err) } @@ -736,7 +736,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { } // Send the settle packet from the remote peer through the switch. - if err := s2.forward(settle); err != nil { + if err := <-s2.ForwardPackets(nil, settle); err != nil { t.Fatalf(err.Error()) } @@ -761,7 +761,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { } // Send the settle packet again, which should fail. - if err := s2.forward(settle); err != nil { + if err := <-s2.ForwardPackets(nil, settle); err != nil { t.Fatalf("expected success when sending duplicate settle " + "with no pending circuit") } @@ -844,7 +844,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(ogPacket); err != nil { + if err := forwardPackets(t, s, ogPacket); err != nil { t.Fatal(err) } @@ -915,12 +915,10 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { t.Fatalf("wrong amount of half circuits") } - // Resend the failed htlc, it should be returned to alice since the + // Resend the failed htlc. The packet will be dropped silently since the // switch will detect that it has been half added previously. - err = s2.forward(ogPacket) - if err != ErrDuplicateAdd { - t.Fatal("unexpected error when reforwarding a "+ - "failed packet", err) + if err := <-s2.ForwardPackets(nil, ogPacket); err != nil { + t.Fatal(err) } // After detecting an incomplete forward, the fail packet should have @@ -1011,7 +1009,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(ogPacket); err != nil { + if err := forwardPackets(t, s, ogPacket); err != nil { t.Fatal(err) } @@ -1079,20 +1077,20 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { // Resend the failed htlc, it should be returned to alice since the // switch will detect that it has been half added previously. - err = s2.forward(ogPacket) - linkErr, ok := err.(*LinkError) - if !ok { - t.Fatalf("expected link error, got: %T", err) - } - if linkErr.FailureDetail != OutgoingFailureIncompleteForward { - t.Fatalf("expected incomplete forward, got: %v", - linkErr.FailureDetail) + err = <-s2.ForwardPackets(nil, ogPacket) + if err != nil { + t.Fatal(err) } // After detecting an incomplete forward, the fail packet should have // been returned to the sender. select { - case <-aliceChannelLink.packets: + case pkt := <-aliceChannelLink.packets: + linkErr := pkt.linkFailure + if linkErr.FailureDetail != OutgoingFailureIncompleteForward { + t.Fatalf("expected incomplete forward, got: %v", + linkErr.FailureDetail) + } case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } @@ -1177,7 +1175,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(ogPacket); err != nil { + if err := forwardPackets(t, s, ogPacket); err != nil { t.Fatal(err) } @@ -1267,7 +1265,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { } // Handle the request and checks that payment circuit works properly. - if err := s2.forward(ogPacket); err != nil { + if err := <-s2.ForwardPackets(nil, ogPacket); err != nil { t.Fatal(err) } @@ -1417,7 +1415,7 @@ func TestCircularForwards(t *testing.T) { // Attempt to forward the packet and check for the expected // error. - err = s.forward(packet) + err = forwardPackets(t, s, packet) if !reflect.DeepEqual(err, test.expectedErr) { t.Fatalf("expected: %v, got: %v", test.expectedErr, err) @@ -1637,7 +1635,7 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T, } // The request to forward should fail as - err = s.forward(packet) + err = forwardPackets(t, s, packet) failure := obfuscator.(*mockObfuscator).failure if testCase.expectedReply == lnwire.CodeNone { @@ -1796,7 +1794,7 @@ func TestSwitchCancel(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(request); err != nil { + if err := forwardPackets(t, s, request); err != nil { t.Fatal(err) } @@ -1828,7 +1826,7 @@ func TestSwitchCancel(t *testing.T) { } // Handle the request and checks that payment circuit works properly. - if err := s.forward(request); err != nil { + if err := forwardPackets(t, s, request); err != nil { t.Fatal(err) } @@ -1911,7 +1909,7 @@ func TestSwitchAddSamePayment(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(request); err != nil { + if err := forwardPackets(t, s, request); err != nil { t.Fatal(err) } @@ -1941,7 +1939,7 @@ func TestSwitchAddSamePayment(t *testing.T) { } // Handle the request and checks that bob channel link received it. - if err := s.forward(request); err != nil { + if err := forwardPackets(t, s, request); err != nil { t.Fatal(err) } @@ -1970,7 +1968,7 @@ func TestSwitchAddSamePayment(t *testing.T) { } // Handle the request and checks that payment circuit works properly. - if err := s.forward(request); err != nil { + if err := forwardPackets(t, s, request); err != nil { t.Fatal(err) } @@ -1996,7 +1994,7 @@ func TestSwitchAddSamePayment(t *testing.T) { } // Handle the request and checks that payment circuit works properly. - if err := s.forward(request); err != nil { + if err := forwardPackets(t, s, request); err != nil { t.Fatal(err) } @@ -2139,7 +2137,7 @@ func TestSwitchSendPayment(t *testing.T) { }, } - if err := s.forward(packet); err != nil { + if err := forwardPackets(t, s, packet); err != nil { t.Fatalf("can't forward htlc packet: %v", err) } @@ -2634,7 +2632,7 @@ func TestInvalidFailure(t *testing.T) { }, } - if err := s.forward(packet); err != nil { + if err := forwardPackets(t, s, packet); err != nil { t.Fatalf("can't forward htlc packet: %v", err) } @@ -3060,3 +3058,17 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64, return aliceEvents, bobEvents, carolEvents } + +// forwardPackets forwards packets to the switch and enforces a timeout on the +// reply. +func forwardPackets(t *testing.T, s *Switch, packets ...*htlcPacket) error { + + select { + case err := <-s.ForwardPackets(nil, packets...): + return err + + case <-time.After(time.Second): + t.Fatal("no timely reply from switch") + return nil + } +}