diff --git a/channeldb/channel.go b/channeldb/channel.go index 21468b7f..e2e6658d 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -711,6 +711,12 @@ func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { } func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { + // Special case ChanStatusDefualt since it isn't actually flag, but a + // particular combination (or lack-there-of) of flags. + if status == ChanStatusDefault { + return c.chanStatus == ChanStatusDefault + } + return c.chanStatus&status == status } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index eb068485..6b1e0ab8 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -1581,3 +1581,62 @@ func TestBalanceAtHeight(t *testing.T) { }) } } + +// TestHasChanStatus asserts the behavior of HasChanStatus by checking the +// behavior of various status flags in addition to the special case of +// ChanStatusDefault which is treated like a flag in the code base even though +// it isn't. +func TestHasChanStatus(t *testing.T) { + tests := []struct { + name string + status ChannelStatus + expHas map[ChannelStatus]bool + }{ + { + name: "default", + status: ChanStatusDefault, + expHas: map[ChannelStatus]bool{ + ChanStatusDefault: true, + ChanStatusBorked: false, + }, + }, + { + name: "single flag", + status: ChanStatusBorked, + expHas: map[ChannelStatus]bool{ + ChanStatusDefault: false, + ChanStatusBorked: true, + }, + }, + { + name: "multiple flags", + status: ChanStatusBorked | ChanStatusLocalDataLoss, + expHas: map[ChannelStatus]bool{ + ChanStatusDefault: false, + ChanStatusBorked: true, + ChanStatusLocalDataLoss: true, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + c := &OpenChannel{ + chanStatus: test.status, + } + + for status, expHas := range test.expHas { + has := c.HasChanStatus(status) + if has == expHas { + continue + } + + t.Fatalf("expected chan status to "+ + "have %s? %t, got: %t", + status, expHas, has) + } + }) + } +}