From cb85095ab009e0f4f7a922e36f258fb40ae02f8a Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Wed, 9 Oct 2019 16:37:25 +0200 Subject: [PATCH] htlcswitch/test: assert on replied failure message --- htlcswitch/mock.go | 3 +++ htlcswitch/switch_test.go | 8 +++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 12067040..19894300 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -334,6 +334,7 @@ var _ hop.Iterator = (*mockHopIterator)(nil) // encodes the failure and do not makes any onion obfuscation. type mockObfuscator struct { ogPacket *sphinx.OnionPacket + failure lnwire.FailureMessage } // NewMockObfuscator initializes a dummy mockObfuscator used for testing. @@ -366,6 +367,8 @@ func (o *mockObfuscator) Reextract( func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( lnwire.OpaqueReason, error) { + o.failure = failure + var b bytes.Buffer if err := lnwire.EncodeFailure(&b, failure, 0); err != nil { return nil, err diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index b0a53eb3..bd7e09fd 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1336,6 +1336,7 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { // Alice. preimage := [sha256.Size]byte{1} rhash := fastsha256.Sum256(preimage[:]) + obfuscator := NewMockObfuscator() packet = &htlcPacket{ incomingChanID: aliceChannelLink.ShortChanID(), incomingHTLCID: 0, @@ -1344,7 +1345,7 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { PaymentHash: rhash, Amount: 1, }, - obfuscator: NewMockObfuscator(), + obfuscator: obfuscator, } // The request to forward should fail as @@ -1353,6 +1354,11 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { t.Fatalf("forwarding should have failed due to inactive link") } + failure := obfuscator.(*mockObfuscator).failure + if _, ok := failure.(*lnwire.FailTemporaryChannelFailure); !ok { + t.Fatalf("unexpected failure %T", failure) + } + if s.circuits.NumOpen() != 0 { t.Fatal("wrong amount of circuits") }