diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index fc34422e..10322dbd 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -11,7 +11,6 @@ import ( "net" "reflect" "runtime" - "strings" "sync" "testing" "time" @@ -623,12 +622,8 @@ func TestExitNodeAmountPayloadMismatch(t *testing.T) { ).Wait(30 * time.Second) if err == nil { t.Fatalf("payment should have failed but didn't") - } else if !strings.Contains(err.Error(), lnwire.CodeUnknownPaymentHash.String()) { - // TODO(roasbeef): use proper error after error propagation is - // in - t.Fatalf("expected %v got %v", err, - lnwire.CodeUnknownPaymentHash) } + assertFailureCode(t, err, lnwire.CodeUnknownPaymentHash) } // TestLinkForwardTimelockPolicyMismatch tests that if a node is an @@ -1025,9 +1020,8 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { ).Wait(30 * time.Second) if err == nil { t.Fatal("error haven't been received") - } else if !strings.Contains(err.Error(), "insufficient capacity") { - t.Fatalf("wrong error has been received: %v", err) } + assertFailureCode(t, err, lnwire.CodeTemporaryChannelFailure) // Wait for Alice to receive the revocation. // @@ -1136,10 +1130,7 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { t.Fatalf("no result arrive") } - fErr := result.Error - if !strings.Contains(fErr.Error(), lnwire.CodeUnknownPaymentHash.String()) { - t.Fatalf("expected %v got %v", lnwire.CodeUnknownPaymentHash, fErr) - } + assertFailureCode(t, result.Error, lnwire.CodeUnknownPaymentHash) // Wait for Alice to receive the revocation. time.Sleep(100 * time.Millisecond) @@ -5846,11 +5837,7 @@ func TestChannelLinkHoldInvoiceCancel(t *testing.T) { // Wait for payment to succeed. err = <-ctx.errChan - if !strings.Contains(err.Error(), - lnwire.CodeUnknownPaymentHash.String()) { - - t.Fatal("expected unknown payment hash") - } + assertFailureCode(t, err, lnwire.CodeUnknownPaymentHash) } // TestChannelLinkHoldInvoiceRestart asserts hodl htlcs are held after blocks @@ -6081,3 +6068,17 @@ func TestChannelLinkRevocationWindowHodl(t *testing.T) { default: } } + +// assertFailureCode asserts that an error is of type ForwardingError and that +// the failure code is as expected. +func assertFailureCode(t *testing.T, err error, code lnwire.FailCode) { + fErr, ok := err.(*ForwardingError) + if !ok { + t.Fatalf("expected ForwardingError but got %T", err) + } + + if fErr.FailureMessage.Code() != code { + t.Fatalf("expected %v but got %v", + code, fErr.FailureMessage.Code()) + } +} diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 13f6acc3..7ed01dea 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1824,14 +1824,8 @@ func TestSwitchSendPayment(t *testing.T) { select { case err := <-errChan: - fErr, ok := err.(*ForwardingError) - if !ok { - t.Fatal("expected ForwardingError") - } + assertFailureCode(t, err, lnwire.CodeUnknownPaymentHash) - if _, ok := fErr.FailureMessage.(*lnwire.FailUnknownPaymentHash); !ok { - t.Fatalf("expected UnknownPaymentHash got %v", fErr) - } case <-time.After(time.Second): t.Fatal("err wasn't received") }