diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 4252b721..97ab7d5b 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -248,10 +248,10 @@ func TestOpenChannelPutGetDelete(t *testing.T) { t.Parallel() cdb, cleanUp, err := makeTestDB() - defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } + defer cleanUp() // Create the test channel state, then add an additional fake HTLC // before syncing to disk. @@ -368,10 +368,10 @@ func TestChannelStateTransition(t *testing.T) { t.Parallel() cdb, cleanUp, err := makeTestDB() - defer cleanUp() if err != nil { t.Fatalf("unable to make test database: %v", err) } + defer cleanUp() // First create a minimal channel, then perform a full sync in order to // persist the data. diff --git a/channeldb/db_test.go b/channeldb/db_test.go index 872e6307..794b1fcf 100644 --- a/channeldb/db_test.go +++ b/channeldb/db_test.go @@ -5,6 +5,9 @@ import ( "os" "path/filepath" "testing" + + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/lnwire" ) func TestOpenWithCreate(t *testing.T) { @@ -71,3 +74,76 @@ func TestWipe(t *testing.T) { ErrNoClosedChannels, err) } } + +// TestFetchClosedChannelForID tests that we are able to properly retrieve a +// ChannelCloseSummary from the DB given a ChannelID. +func TestFetchClosedChannelForID(t *testing.T) { + t.Parallel() + + const numChans = 101 + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state, that we will mutate the index of the + // funding point. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Now run through the number of channels, and modify the outpoint index + // to create new channel IDs. + for i := uint32(0); i < numChans; i++ { + // Save the open channel to disk. + state.FundingOutpoint.Index = i + if err := state.FullSync(); err != nil { + t.Fatalf("unable to save and serialize channel "+ + "state: %v", err) + } + + // Close the channel. To make sure we retrieve the correct + // summary later, we make them differ in the SettledBalance. + closeSummary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + RemotePub: state.IdentityPub, + SettledBalance: btcutil.Amount(500 + i), + } + if err := state.CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + } + + // Now run though them all again and make sure we are able to retrieve + // summaries from the DB. + for i := uint32(0); i < numChans; i++ { + state.FundingOutpoint.Index = i + + // We calculate the ChannelID and use it to fetch the summary. + cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) + fetchedSummary, err := cdb.FetchClosedChannelForID(cid) + if err != nil { + t.Fatalf("unable to fetch close summary: %v", err) + } + + // Make sure we retrieved the correct one by checking the + // SettledBalance. + if fetchedSummary.SettledBalance != btcutil.Amount(500+i) { + t.Fatalf("summaries don't match: expected %v got %v", + btcutil.Amount(500+i), + fetchedSummary.SettledBalance) + } + } + + // As a final test we make sure that we get ErrClosedChannelNotFound + // for a ChannelID we didn't add to the DB. + state.FundingOutpoint.Index++ + cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) + _, err = cdb.FetchClosedChannelForID(cid) + if err != ErrClosedChannelNotFound { + t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err) + } +}