switch/test: use external interface for testing

Previously the forward(...) method was used in forwarding tests,
while that code path isn't used for forwards in reality.
This commit is contained in:
Joost Jager 2020-04-13 14:23:36 +02:00
parent 341308327e
commit babb0a36b4
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7

@ -213,7 +213,7 @@ func TestSwitchSendPending(t *testing.T) {
// Send the ADD packet, this should not be forwarded out to the link // Send the ADD packet, this should not be forwarded out to the link
// since there are no eligible links. // since there are no eligible links.
err = s.forward(packet) err = forwardPackets(t, s, packet)
linkErr, ok := err.(*LinkError) linkErr, ok := err.(*LinkError)
if !ok { if !ok {
t.Fatalf("expected link error, got: %T", err) t.Fatalf("expected link error, got: %T", err)
@ -249,7 +249,7 @@ func TestSwitchSendPending(t *testing.T) {
packet.incomingHTLCID++ packet.incomingHTLCID++
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(packet); err != nil { if err := forwardPackets(t, s, packet); err != nil {
t.Fatalf("unexpected forward failure: %v", err) t.Fatalf("unexpected forward failure: %v", err)
} }
@ -322,7 +322,7 @@ func TestSwitchForward(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(packet); err != nil { if err := forwardPackets(t, s, packet); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -356,7 +356,7 @@ func TestSwitchForward(t *testing.T) {
} }
// Handle the request and checks that payment circuit works properly. // Handle the request and checks that payment circuit works properly.
if err := s.forward(packet); err != nil { if err := forwardPackets(t, s, packet); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -451,7 +451,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil { if err := forwardPackets(t, s, ogPacket); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -539,7 +539,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
} }
// Send the fail packet from the remote peer through the switch. // Send the fail packet from the remote peer through the switch.
if err := s2.forward(fail); err != nil { if err := <-s2.ForwardPackets(nil, fail); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
@ -563,7 +563,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
} }
// Send the fail packet from the remote peer through the switch. // Send the fail packet from the remote peer through the switch.
if err := s2.forward(fail); err == nil { if err := <-s2.ForwardPackets(nil, fail); err == nil {
t.Fatalf("expected failure when sending duplicate fail " + t.Fatalf("expected failure when sending duplicate fail " +
"with no pending circuit") "with no pending circuit")
} }
@ -646,7 +646,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil { if err := forwardPackets(t, s, ogPacket); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -736,7 +736,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
} }
// Send the settle packet from the remote peer through the switch. // Send the settle packet from the remote peer through the switch.
if err := s2.forward(settle); err != nil { if err := <-s2.ForwardPackets(nil, settle); err != nil {
t.Fatalf(err.Error()) t.Fatalf(err.Error())
} }
@ -761,7 +761,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
} }
// Send the settle packet again, which should fail. // Send the settle packet again, which should fail.
if err := s2.forward(settle); err != nil { if err := <-s2.ForwardPackets(nil, settle); err != nil {
t.Fatalf("expected success when sending duplicate settle " + t.Fatalf("expected success when sending duplicate settle " +
"with no pending circuit") "with no pending circuit")
} }
@ -844,7 +844,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil { if err := forwardPackets(t, s, ogPacket); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -915,12 +915,10 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
t.Fatalf("wrong amount of half circuits") t.Fatalf("wrong amount of half circuits")
} }
// Resend the failed htlc, it should be returned to alice since the // Resend the failed htlc. The packet will be dropped silently since the
// switch will detect that it has been half added previously. // switch will detect that it has been half added previously.
err = s2.forward(ogPacket) if err := <-s2.ForwardPackets(nil, ogPacket); err != nil {
if err != ErrDuplicateAdd { t.Fatal(err)
t.Fatal("unexpected error when reforwarding a "+
"failed packet", err)
} }
// After detecting an incomplete forward, the fail packet should have // After detecting an incomplete forward, the fail packet should have
@ -1011,7 +1009,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil { if err := forwardPackets(t, s, ogPacket); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1079,20 +1077,20 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
// Resend the failed htlc, it should be returned to alice since the // Resend the failed htlc, it should be returned to alice since the
// switch will detect that it has been half added previously. // switch will detect that it has been half added previously.
err = s2.forward(ogPacket) err = <-s2.ForwardPackets(nil, ogPacket)
linkErr, ok := err.(*LinkError) if err != nil {
if !ok { t.Fatal(err)
t.Fatalf("expected link error, got: %T", err)
}
if linkErr.FailureDetail != OutgoingFailureIncompleteForward {
t.Fatalf("expected incomplete forward, got: %v",
linkErr.FailureDetail)
} }
// After detecting an incomplete forward, the fail packet should have // After detecting an incomplete forward, the fail packet should have
// been returned to the sender. // been returned to the sender.
select { select {
case <-aliceChannelLink.packets: case pkt := <-aliceChannelLink.packets:
linkErr := pkt.linkFailure
if linkErr.FailureDetail != OutgoingFailureIncompleteForward {
t.Fatalf("expected incomplete forward, got: %v",
linkErr.FailureDetail)
}
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatal("request was not propagated to destination") t.Fatal("request was not propagated to destination")
} }
@ -1177,7 +1175,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil { if err := forwardPackets(t, s, ogPacket); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1267,7 +1265,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
} }
// Handle the request and checks that payment circuit works properly. // Handle the request and checks that payment circuit works properly.
if err := s2.forward(ogPacket); err != nil { if err := <-s2.ForwardPackets(nil, ogPacket); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1417,7 +1415,7 @@ func TestCircularForwards(t *testing.T) {
// Attempt to forward the packet and check for the expected // Attempt to forward the packet and check for the expected
// error. // error.
err = s.forward(packet) err = forwardPackets(t, s, packet)
if !reflect.DeepEqual(err, test.expectedErr) { if !reflect.DeepEqual(err, test.expectedErr) {
t.Fatalf("expected: %v, got: %v", t.Fatalf("expected: %v, got: %v",
test.expectedErr, err) test.expectedErr, err)
@ -1637,7 +1635,7 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T,
} }
// The request to forward should fail as // The request to forward should fail as
err = s.forward(packet) err = forwardPackets(t, s, packet)
failure := obfuscator.(*mockObfuscator).failure failure := obfuscator.(*mockObfuscator).failure
if testCase.expectedReply == lnwire.CodeNone { if testCase.expectedReply == lnwire.CodeNone {
@ -1796,7 +1794,7 @@ func TestSwitchCancel(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(request); err != nil { if err := forwardPackets(t, s, request); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1828,7 +1826,7 @@ func TestSwitchCancel(t *testing.T) {
} }
// Handle the request and checks that payment circuit works properly. // Handle the request and checks that payment circuit works properly.
if err := s.forward(request); err != nil { if err := forwardPackets(t, s, request); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1911,7 +1909,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(request); err != nil { if err := forwardPackets(t, s, request); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1941,7 +1939,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
} }
// Handle the request and checks that bob channel link received it. // Handle the request and checks that bob channel link received it.
if err := s.forward(request); err != nil { if err := forwardPackets(t, s, request); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1970,7 +1968,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
} }
// Handle the request and checks that payment circuit works properly. // Handle the request and checks that payment circuit works properly.
if err := s.forward(request); err != nil { if err := forwardPackets(t, s, request); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1996,7 +1994,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
} }
// Handle the request and checks that payment circuit works properly. // Handle the request and checks that payment circuit works properly.
if err := s.forward(request); err != nil { if err := forwardPackets(t, s, request); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -2139,7 +2137,7 @@ func TestSwitchSendPayment(t *testing.T) {
}, },
} }
if err := s.forward(packet); err != nil { if err := forwardPackets(t, s, packet); err != nil {
t.Fatalf("can't forward htlc packet: %v", err) t.Fatalf("can't forward htlc packet: %v", err)
} }
@ -2634,7 +2632,7 @@ func TestInvalidFailure(t *testing.T) {
}, },
} }
if err := s.forward(packet); err != nil { if err := forwardPackets(t, s, packet); err != nil {
t.Fatalf("can't forward htlc packet: %v", err) t.Fatalf("can't forward htlc packet: %v", err)
} }
@ -3060,3 +3058,17 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
return aliceEvents, bobEvents, carolEvents return aliceEvents, bobEvents, carolEvents
} }
// forwardPackets forwards packets to the switch and enforces a timeout on the
// reply.
func forwardPackets(t *testing.T, s *Switch, packets ...*htlcPacket) error {
select {
case err := <-s.ForwardPackets(nil, packets...):
return err
case <-time.After(time.Second):
t.Fatal("no timely reply from switch")
return nil
}
}