diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index d33a413e..c396e580 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -397,7 +397,12 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { t.Fatalf("payment should have failed but didn't") } - switch err.(type) { + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailFinalIncorrectCltvExpiry: default: t.Fatalf("incorrect error, expected incorrect cltv expiry, "+ @@ -481,7 +486,12 @@ func TestLinkForwardTimelockPolicyMismatch(t *testing.T) { t.Fatalf("payment should have failed but didn't") } - switch err.(type) { + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailIncorrectCltvExpiry: default: t.Fatalf("incorrect error, expected incorrect cltv expiry, "+ @@ -527,8 +537,12 @@ func TestLinkForwardFeePolicyMismatch(t *testing.T) { t.Fatalf("payment should have failed but didn't") } - switch err.(type) { - // TODO(roasbeef): assert get proper fee back + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailFeeInsufficient: default: t.Fatalf("incorrect error, expected fee insufficient, "+ @@ -574,7 +588,12 @@ func TestLinkForwardMinHTLCPolicyMismatch(t *testing.T) { t.Fatalf("payment should have failed but didn't") } - switch err.(type) { + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailAmountBelowMinimum: default: t.Fatalf("incorrect error, expected amount below minimum, "+ @@ -896,7 +915,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { // Replace decode function with another which throws an error. n.carolChannelLink.cfg.DecodeOnionObfuscator = func( - r io.Reader) (Obfuscator, lnwire.FailCode) { + r io.Reader) (ErrorEncrypter, lnwire.FailCode) { return nil, lnwire.CodeInvalidOnionVersion } @@ -914,7 +933,13 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { if err == nil { t.Fatal("error haven't been received") } - switch err.(type) { + + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailInvalidOnionVersion: default: t.Fatalf("wrong error have been received: %v", err) @@ -987,7 +1012,12 @@ func TestChannelLinkExpiryTooSoonExitNode(t *testing.T) { "time lock value") } - switch err.(type) { + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailFinalIncorrectCltvExpiry: default: t.Fatalf("incorrect error, expected final time lock too "+ @@ -1033,7 +1063,12 @@ func TestChannelLinkExpiryTooSoonMidNode(t *testing.T) { "time lock value") } - switch err.(type) { + ferr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected a ForwardingError, instead got: %T", err) + } + + switch ferr.FailureMessage.(type) { case *lnwire.FailExpiryTooSoon: default: t.Fatalf("incorrect error, expected final time lock too "+ @@ -1203,7 +1238,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro Peer: &alicePeer, Switch: nil, DecodeHopIterator: decoder.DecodeHopIterator, - DecodeOnionObfuscator: func(io.Reader) (Obfuscator, lnwire.FailCode) { + DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) { return obfuscator, lnwire.CodeNone }, GetLastChannelUpdate: mockGetChanUpdateMessage,