diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 927c380c..cb29b521 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -187,6 +187,13 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { } } +// channelIDOption is an option which sets the short channel ID of the channel. +var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { + return func(params *testChannelParams) { + params.channel.ShortChannelID = chanID + } +} + // createTestChannel writes a test channel to the database. It takes a set of // functional options which can be used to overwrite the default of creating // a pending channel that was broadcast at height 100. diff --git a/channeldb/db.go b/channeldb/db.go index c9410bfb..696f3f00 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -556,42 +556,28 @@ func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { // within the database, including pending open, fully open and channels waiting // for a closing transaction to confirm. func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - var channels []*OpenChannel - - // TODO(halseth): fetch all in one db tx. - openChannels, err := d.FetchAllOpenChannels() - if err != nil { - return nil, err - } - channels = append(channels, openChannels...) - - pendingChannels, err := d.FetchPendingChannels() - if err != nil { - return nil, err - } - channels = append(channels, pendingChannels...) - - waitingClose, err := d.FetchWaitingCloseChannels() - if err != nil { - return nil, err - } - channels = append(channels, waitingClose...) - - return channels, nil + return fetchChannels(d) } // FetchAllOpenChannels will return all channels that have the funding // transaction confirmed, and is not waiting for a closing transaction to be // confirmed. func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { - return fetchChannels(d, false, false) + return fetchChannels( + d, + pendingChannelFilter(false), + waitingCloseFilter(false), + ) } // FetchPendingChannels will return channels that have completed the process of // generating and broadcasting funding transactions, but whose funding // transactions have yet to be confirmed on the blockchain. func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, true, false) + return fetchChannels(d, + pendingChannelFilter(true), + waitingCloseFilter(false), + ) } // FetchWaitingCloseChannels will return all channels that have been opened, @@ -599,25 +585,49 @@ func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { // // NOTE: This includes channels that are also pending to be opened. func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { - waitingClose, err := fetchChannels(d, false, true) - if err != nil { - return nil, err - } - pendingWaitingClose, err := fetchChannels(d, true, true) - if err != nil { - return nil, err - } + return fetchChannels( + d, waitingCloseFilter(true), + ) +} - return append(waitingClose, pendingWaitingClose...), nil +// fetchChannelsFilter applies a filter to channels retrieved in fetchchannels. +// A set of filters can be combined to filter across multiple dimensions. +type fetchChannelsFilter func(channel *OpenChannel) bool + +// pendingChannelFilter returns a filter based on whether channels are pending +// (ie, their funding transaction still needs to confirm). If pending is false, +// channels with confirmed funding transactions are returned. +func pendingChannelFilter(pending bool) fetchChannelsFilter { + return func(channel *OpenChannel) bool { + return channel.IsPending == pending + } +} + +// waitingCloseFilter returns a filter which filters channels based on whether +// they are awaiting the confirmation of their closing transaction. If waiting +// close is true, channels that have had their closing tx broadcast are +// included. If it is false, channels that are not awaiting confirmation of +// their close transaction are returned. +func waitingCloseFilter(waitingClose bool) fetchChannelsFilter { + return func(channel *OpenChannel) bool { + // If the channel is in any other state than Default, + // then it means it is waiting to be closed. + channelWaitingClose := + channel.ChanStatus() != ChanStatusDefault + + // Include the channel if it matches the value for + // waiting close that we are filtering on. + return channelWaitingClose == waitingClose + } } // fetchChannels attempts to retrieve channels currently stored in the -// database. The pending parameter determines whether only pending channels -// will be returned, or only open channels will be returned. The waitingClose -// parameter determines whether only channels waiting for a closing transaction -// to be confirmed should be returned. If no active channels exist within the -// network, then ErrNoActiveChannels is returned. -func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { +// database. It takes a set of filters which are applied to each channel to +// obtain a set of channels with the desired set of properties. Only channels +// which have a true value returned for *all* of the filters will be returned. +// If no filters are provided, every channel in the open channels bucket will +// be returned. +func fetchChannels(d *DB, filters ...fetchChannelsFilter) ([]*OpenChannel, error) { var channels []*OpenChannel err := d.View(func(tx *bbolt.Tx) error { @@ -667,24 +677,27 @@ func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { "node_key=%x: %v", chainHash[:], k, err) } for _, channel := range nodeChans { - if channel.IsPending != pending { - continue + // includeChannel indicates whether the channel + // meets the criteria specified by our filters. + includeChannel := true + + // Run through each filter and check whether the + // channel should be included. + for _, f := range filters { + // If the channel fails the filter, set + // includeChannel to false and don't bother + // checking the remaining filters. + if !f(channel) { + includeChannel = false + break + } } - // If the channel is in any other state - // than Default, then it means it is - // waiting to be closed. - channelWaitingClose := - channel.ChanStatus() != ChanStatusDefault - - // Only include it if we requested - // channels with the same waitingClose - // status. - if channelWaitingClose != waitingClose { - continue + // If the channel passed every filter, include it in + // our set of channels. + if includeChannel { + channels = append(channels, channel) } - - channels = append(channels, channel) } return nil }) diff --git a/channeldb/db_test.go b/channeldb/db_test.go index c1f52507..935a287c 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -500,3 +500,196 @@ func TestAbandonChannel(t *testing.T) { t.Fatalf("unable to abandon channel: %v", err) } } + +// TestFetchChannels tests the filtering of open channels in fetchChannels. +// It tests the case where no filters are provided (which is equivalent to +// FetchAllOpenChannels) and every combination of pending and waiting close. +func TestFetchChannels(t *testing.T) { + // Create static channel IDs for each kind of channel retrieved by + // fetchChannels so that the expected channel IDs can be set in tests. + var ( + // Pending is a channel that is pending open, and has not had + // a close initiated. + pendingChan = lnwire.NewShortChanIDFromInt(1) + + // pendingWaitingClose is a channel that is pending open and + // has has its closing transaction broadcast. + pendingWaitingChan = lnwire.NewShortChanIDFromInt(2) + + // openChan is a channel that has confirmed on chain. + openChan = lnwire.NewShortChanIDFromInt(3) + + // openWaitingChan is a channel that has confirmed on chain, + // and it waiting for its close transaction to confirm. + openWaitingChan = lnwire.NewShortChanIDFromInt(4) + ) + + tests := []struct { + name string + filters []fetchChannelsFilter + expectedChannels map[lnwire.ShortChannelID]bool + }{ + { + name: "get all channels", + filters: []fetchChannelsFilter{}, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + pendingWaitingChan: true, + openChan: true, + openWaitingChan: true, + }, + }, + { + name: "pending channels", + filters: []fetchChannelsFilter{ + pendingChannelFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + pendingWaitingChan: true, + }, + }, + { + name: "open channels", + filters: []fetchChannelsFilter{ + pendingChannelFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + openChan: true, + openWaitingChan: true, + }, + }, + { + name: "waiting close channels", + filters: []fetchChannelsFilter{ + waitingCloseFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingWaitingChan: true, + openWaitingChan: true, + }, + }, + { + name: "not waiting close channels", + filters: []fetchChannelsFilter{ + waitingCloseFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + openChan: true, + }, + }, + { + name: "pending waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(true), + waitingCloseFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingWaitingChan: true, + }, + }, + { + name: "pending, not waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(true), + waitingCloseFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + pendingChan: true, + }, + }, + { + name: "open waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(false), + waitingCloseFilter(true), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + openWaitingChan: true, + }, + }, + { + name: "open, not waiting", + filters: []fetchChannelsFilter{ + pendingChannelFilter(false), + waitingCloseFilter(false), + }, + expectedChannels: map[lnwire.ShortChannelID]bool{ + openChan: true, + }, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test "+ + "database: %v", err) + } + defer cleanUp() + + // Create a pending channel that is not awaiting close. + createTestChannel( + t, cdb, channelIDOption(pendingChan), + ) + + // Create a pending channel which has has been marked as + // broadcast, indicating that its closing transaction is + // waiting to confirm. + pendingClosing := createTestChannel( + t, cdb, + channelIDOption(pendingWaitingChan), + ) + + err = pendingClosing.MarkCoopBroadcasted(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Create a open channel that is not awaiting close. + createTestChannel( + t, cdb, + channelIDOption(openChan), + openChannelOption(), + ) + + // Create a open channel which has has been marked as + // broadcast, indicating that its closing transaction is + // waiting to confirm. + openClosing := createTestChannel( + t, cdb, + channelIDOption(openWaitingChan), + openChannelOption(), + ) + err = openClosing.MarkCoopBroadcasted(nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + channels, err := fetchChannels(cdb, test.filters...) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(channels) != len(test.expectedChannels) { + t.Fatalf("expected: %v channels, "+ + "got: %v", len(test.expectedChannels), + len(channels)) + } + + for _, ch := range channels { + _, ok := test.expectedChannels[ch.ShortChannelID] + if !ok { + t.Fatalf("fetch channels unexpected "+ + "channel: %v", ch.ShortChannelID) + } + } + }) + } +}