From e5ead599ccedfab4c57f79323890ab640587130e Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 8 Apr 2019 11:29:18 +0200 Subject: [PATCH] htlcswitch/test: use single channel restore function This commit refactors test code around channel restoration in unit tests to make it easier to use. --- htlcswitch/link_test.go | 40 +++++++------ htlcswitch/test_utils.go | 120 ++++++++++++++++++++++++--------------- 2 files changed, 98 insertions(+), 62 deletions(-) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 7aaf1da3..2ef764b7 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -180,7 +180,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { t.Parallel() // Setup a alice-bob network. - aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels( + alice, bob, cleanUp, err := createTwoClusterChannels( btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5) if err != nil { @@ -188,7 +188,9 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { } defer cleanUp() - n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight) + n := newTwoHopNetwork( + t, alice.channel, bob.channel, testStartingHeight, + ) if err := n.start(); err != nil { t.Fatal(err) } @@ -1592,7 +1594,7 @@ func (m *mockPeer) Address() net.Addr { func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( ChannelLink, *lnwallet.LightningChannel, chan time.Time, func() error, - func(), chanRestoreFunc, error) { + func(), func() (*lnwallet.LightningChannel, error), error) { var chanIDBytes [8]byte if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil { @@ -1602,7 +1604,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( chanID := lnwire.NewShortChanIDFromInt( binary.BigEndian.Uint64(chanIDBytes[:])) - aliceChannel, bobChannel, fCleanUp, restore, err := createTestChannel( + aliceLc, bobLc, fCleanUp, err := createTestChannel( alicePrivKey, bobPrivKey, chanAmt, chanAmt, chanReserve, chanReserve, chanID, ) @@ -1628,7 +1630,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( pCache := newMockPreimageCache() - aliceDb := aliceChannel.State().Db + aliceDb := aliceLc.channel.State().Db aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) if err != nil { return nil, nil, nil, nil, nil, nil, err @@ -1668,7 +1670,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } const startingHeight = 100 - aliceLink := NewChannelLink(aliceCfg, aliceChannel) + aliceLink := NewChannelLink(aliceCfg, aliceLc.channel) start := func() error { return aliceSwitch.AddLink(aliceLink) } @@ -1687,7 +1689,8 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( defer fCleanUp() } - return aliceLink, bobChannel, bticker.Force, start, cleanUp, restore, nil + return aliceLink, bobLc.channel, bticker.Force, start, cleanUp, + aliceLc.restore, nil } func assertLinkBandwidth(t *testing.T, link ChannelLink, @@ -2546,7 +2549,9 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) { t.Fatalf("unable to start test harness: %v", err) } - alice := newPersistentLinkHarness(t, aliceLink, batchTicker, restore) + alice := newPersistentLinkHarness( + t, aliceLink, batchTicker, restore, + ) // Compute the static fees that will be used to determine the // correctness of Alice's bandwidth when forwarding HTLCs. @@ -2818,7 +2823,9 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) { t.Fatalf("unable to start test harness: %v", err) } - alice := newPersistentLinkHarness(t, aliceLink, batchTicker, restore) + alice := newPersistentLinkHarness( + t, aliceLink, batchTicker, restore, + ) // We'll put Alice into hodl.Commit mode, such that the circuits for any // outgoing ADDs are opened, but the changes are not committed in the @@ -3980,14 +3987,15 @@ type persistentLinkHarness struct { batchTicker chan time.Time msgs chan lnwire.Message - restoreChan chanRestoreFunc + restoreChan func() (*lnwallet.LightningChannel, error) } // newPersistentLinkHarness initializes a new persistentLinkHarness and derives // the supporting references from the active link. func newPersistentLinkHarness(t *testing.T, link ChannelLink, batchTicker chan time.Time, - restore chanRestoreFunc) *persistentLinkHarness { + restore func() (*lnwallet.LightningChannel, + error)) *persistentLinkHarness { coreLink := link.(*channelLink) @@ -4034,7 +4042,7 @@ func (h *persistentLinkHarness) restart(restartSwitch bool, // state, we will restore the persisted state to ensure we always start // the link in a consistent state. var err error - h.channel, _, err = h.restoreChan() + h.channel, err = h.restoreChan() if err != nil { h.t.Fatalf("unable to restore channels: %v", err) } @@ -5573,7 +5581,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) { t.Parallel() // Setup a alice-bob network. - aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels( + alice, bob, cleanUp, err := createTwoClusterChannels( btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5) if err != nil { @@ -5581,7 +5589,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) { } defer cleanUp() - n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight) + n := newTwoHopNetwork(t, alice.channel, bob.channel, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } @@ -5638,7 +5646,7 @@ type hodlInvoiceTestCtx struct { func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) { // Setup a alice-bob network. - aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels( + alice, bob, cleanUp, err := createTwoClusterChannels( btcutil.SatoshiPerBitcoin*3, btcutil.SatoshiPerBitcoin*5, ) @@ -5646,7 +5654,7 @@ func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) { t.Fatalf("unable to create channel: %v", err) } - n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight) + n := newTwoHopNetwork(t, alice.channel, bob.channel, testStartingHeight) if err := n.start(); err != nil { t.Fatal(err) } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index d749eaa2..f2557761 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -147,15 +147,19 @@ func generateRandomBytes(n int) ([]byte, error) { return b, nil } +type testLightningChannel struct { + channel *lnwallet.LightningChannel + restore func() (*lnwallet.LightningChannel, error) +} + // createTestChannel creates the channel and returns our and remote channels // representations. // // TODO(roasbeef): need to factor out, similar func re-used in many parts of codebase func createTestChannel(alicePrivKey, bobPrivKey []byte, aliceAmount, bobAmount, aliceReserve, bobReserve btcutil.Amount, - chanID lnwire.ShortChannelID) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(), - func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel, - error), error) { + chanID lnwire.ShortChannelID) (*testLightningChannel, + *testLightningChannel, func(), error) { aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), alicePrivKey) bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), bobPrivKey) @@ -187,7 +191,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, var hash [sha256.Size]byte randomSeed, err := generateRandomBytes(sha256.Size) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } copy(hash[:], randomSeed) @@ -236,23 +240,23 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize()) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot) bobFirstRevoke, err := bobPreimageProducer.AtIndex(0) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:]) aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize()) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot) aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:]) @@ -260,25 +264,25 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, bobAmount, &aliceCfg, &bobCfg, aliceCommitPoint, bobCommitPoint, *fundingTxIn) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } alicePath, err := ioutil.TempDir("", "alicedb") dbAlice, err := channeldb.Open(alicePath) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } bobPath, err := ioutil.TempDir("", "bobdb") dbBob, err := channeldb.Open(bobPath) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } estimator := lnwallet.NewStaticFeeEstimator(6000, 0) feePerKw, err := estimator.EstimateFeePerKW(1) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } commitFee := feePerKw.FeeForWeight(724) @@ -350,11 +354,11 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } if err := bobChannelState.SyncPending(aliceAddr, broadcastHeight); err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } cleanUpFunc := func() { @@ -372,7 +376,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, aliceSigner, aliceChannelState, alicePool, ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } alicePool.Start() @@ -381,7 +385,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, bobSigner, bobChannelState, bobPool, ) if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } bobPool.Start() @@ -389,40 +393,38 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, // having Alice and Bob extend their revocation windows to each other. aliceNextRevoke, err := channelAlice.NextRevocationKey() if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } if err := channelBob.InitNextRevocation(aliceNextRevoke); err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } bobNextRevoke, err := channelBob.NextRevocationKey() if err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } if err := channelAlice.InitNextRevocation(bobNextRevoke); err != nil { - return nil, nil, nil, nil, err + return nil, nil, nil, err } - restore := func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel, - error) { - + restoreAlice := func() (*lnwallet.LightningChannel, error) { aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub) switch err { case nil: case bbolt.ErrDatabaseNotOpen: dbAlice, err = channeldb.Open(dbAlice.Path()) if err != nil { - return nil, nil, errors.Errorf("unable to reopen alice "+ + return nil, errors.Errorf("unable to reopen alice "+ "db: %v", err) } aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub) if err != nil { - return nil, nil, errors.Errorf("unable to fetch alice "+ + return nil, errors.Errorf("unable to fetch alice "+ "channel: %v", err) } default: - return nil, nil, errors.Errorf("unable to fetch alice channel: "+ + return nil, errors.Errorf("unable to fetch alice channel: "+ "%v", err) } @@ -435,34 +437,38 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } if aliceStoredChannel == nil { - return nil, nil, errors.New("unable to find stored alice channel") + return nil, errors.New("unable to find stored alice channel") } newAliceChannel, err := lnwallet.NewLightningChannel( aliceSigner, aliceStoredChannel, alicePool, ) if err != nil { - return nil, nil, errors.Errorf("unable to create new channel: %v", + return nil, errors.Errorf("unable to create new channel: %v", err) } + return newAliceChannel, nil + } + + restoreBob := func() (*lnwallet.LightningChannel, error) { bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub) switch err { case nil: case bbolt.ErrDatabaseNotOpen: dbBob, err = channeldb.Open(dbBob.Path()) if err != nil { - return nil, nil, errors.Errorf("unable to reopen bob "+ + return nil, errors.Errorf("unable to reopen bob "+ "db: %v", err) } bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub) if err != nil { - return nil, nil, errors.Errorf("unable to fetch bob "+ + return nil, errors.Errorf("unable to fetch bob "+ "channel: %v", err) } default: - return nil, nil, errors.Errorf("unable to fetch bob channel: "+ + return nil, errors.Errorf("unable to fetch bob channel: "+ "%v", err) } @@ -475,20 +481,31 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte, } if bobStoredChannel == nil { - return nil, nil, errors.New("unable to find stored bob channel") + return nil, errors.New("unable to find stored bob channel") } newBobChannel, err := lnwallet.NewLightningChannel( bobSigner, bobStoredChannel, bobPool, ) if err != nil { - return nil, nil, errors.Errorf("unable to create new channel: %v", + return nil, errors.Errorf("unable to create new channel: %v", err) } - return newAliceChannel, newBobChannel, nil + return newBobChannel, nil } - return channelAlice, channelBob, cleanUpFunc, restore, nil + testLightningChannelAlice := &testLightningChannel{ + channel: channelAlice, + restore: restoreAlice, + } + + testLightningChannelBob := &testLightningChannel{ + channel: channelBob, + restore: restoreBob, + } + + return testLightningChannelAlice, testLightningChannelBob, cleanUpFunc, + nil } // getChanID retrieves the channel point from an lnnwire message. @@ -825,7 +842,7 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( _, _, firstChanID, secondChanID := genIDs() // Create lightning channels between Alice<->Bob and Bob<->Carol - aliceChannel, firstBobChannel, cleanAliceBob, restoreAliceBob, err := + aliceChannel, firstBobChannel, cleanAliceBob, err := createTestChannel(alicePrivKey, bobPrivKey, aliceToBob, aliceToBob, 0, 0, firstChanID) if err != nil { @@ -833,7 +850,7 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( "alice<->bob channel: %v", err) } - secondBobChannel, carolChannel, cleanBobCarol, restoreBobCarol, err := + secondBobChannel, carolChannel, cleanBobCarol, err := createTestChannel(bobPrivKey, carolPrivKey, bobToCarol, bobToCarol, 0, 0, secondChanID) if err != nil { @@ -848,12 +865,23 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( } restoreFromDb := func() (*clusterChannels, error) { - a2b, b2a, err := restoreAliceBob() + + a2b, err := aliceChannel.restore() if err != nil { return nil, err } - b2c, c2b, err := restoreBobCarol() + b2a, err := firstBobChannel.restore() + if err != nil { + return nil, err + } + + b2c, err := secondBobChannel.restore() + if err != nil { + return nil, err + } + + c2b, err := carolChannel.restore() if err != nil { return nil, err } @@ -867,10 +895,10 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( } return &clusterChannels{ - aliceToBob: aliceChannel, - bobToAlice: firstBobChannel, - bobToCarol: secondBobChannel, - carolToBob: carolChannel, + aliceToBob: aliceChannel.channel, + bobToAlice: firstBobChannel.channel, + bobToCarol: secondBobChannel.channel, + carolToBob: carolChannel.channel, }, cleanUp, restoreFromDb, nil } @@ -969,13 +997,13 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, // createTwoClusterChannels creates lightning channels which are needed for // a 2 hop network cluster to be initialized. func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( - *lnwallet.LightningChannel, *lnwallet.LightningChannel, + *testLightningChannel, *testLightningChannel, func(), error) { _, _, firstChanID, _ := genIDs() // Create lightning channels between Alice<->Bob and Bob<->Carol - aliceChannel, firstBobChannel, cleanAliceBob, _, err := + alice, bob, cleanAliceBob, err := createTestChannel(alicePrivKey, bobPrivKey, aliceToBob, aliceToBob, 0, 0, firstChanID) if err != nil { @@ -983,7 +1011,7 @@ func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( "alice<->bob channel: %v", err) } - return aliceChannel, firstBobChannel, cleanAliceBob, nil + return alice, bob, cleanAliceBob, nil } // hopNetwork is the base struct for two and three hop networks