diff --git a/channeldb/addr_test.go b/channeldb/addr_test.go index c4bc4e8e..c3d4e07f 100644 --- a/channeldb/addr_test.go +++ b/channeldb/addr_test.go @@ -3,6 +3,7 @@ package channeldb import ( "bytes" "net" + "strings" "testing" "github.com/lightningnetwork/lnd/tor" @@ -15,7 +16,7 @@ func (t unknownAddrType) String() string { return "unknown" } var addrTests = []struct { expAddr net.Addr - serErr error + serErr string }{ { expAddr: &net.TCPAddr{ @@ -43,7 +44,7 @@ var addrTests = []struct { }, { expAddr: unknownAddrType{}, - serErr: ErrUnknownAddressType, + serErr: ErrUnknownAddressType.Error(), }, } @@ -55,11 +56,21 @@ func TestAddrSerialization(t *testing.T) { var b bytes.Buffer for _, test := range addrTests { err := serializeAddr(&b, test.expAddr) - if err != test.serErr { + switch { + case err == nil && test.serErr != "": + t.Fatalf("expected serialization err for addr %v", + test.expAddr) + + case err != nil && test.serErr == "": + t.Fatalf("unexpected serialization err for addr %v: %v", + test.expAddr, err) + + case err != nil && !strings.Contains(err.Error(), test.serErr): t.Fatalf("unexpected serialization err for addr %v, "+ - "want: %v, got %v", - test.expAddr, test.serErr, err) - } else if test.serErr != nil { + "want: %v, got %v", test.expAddr, test.serErr, + err) + + case err != nil: continue }