htlcswitch/test: use single channel restore function

This commit refactors test code around channel restoration in unit
tests to make it easier to use.
This commit is contained in:
Joost Jager 2019-04-08 11:29:18 +02:00
parent 3d17c2bcfe
commit e5ead599cc
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
2 changed files with 98 additions and 62 deletions

@ -180,7 +180,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
t.Parallel()
// Setup a alice-bob network.
aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels(
alice, bob, cleanUp, err := createTwoClusterChannels(
btcutil.SatoshiPerBitcoin*3,
btcutil.SatoshiPerBitcoin*5)
if err != nil {
@ -188,7 +188,9 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
}
defer cleanUp()
n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight)
n := newTwoHopNetwork(
t, alice.channel, bob.channel, testStartingHeight,
)
if err := n.start(); err != nil {
t.Fatal(err)
}
@ -1592,7 +1594,7 @@ func (m *mockPeer) Address() net.Addr {
func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
ChannelLink, *lnwallet.LightningChannel, chan time.Time, func() error,
func(), chanRestoreFunc, error) {
func(), func() (*lnwallet.LightningChannel, error), error) {
var chanIDBytes [8]byte
if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil {
@ -1602,7 +1604,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
chanID := lnwire.NewShortChanIDFromInt(
binary.BigEndian.Uint64(chanIDBytes[:]))
aliceChannel, bobChannel, fCleanUp, restore, err := createTestChannel(
aliceLc, bobLc, fCleanUp, err := createTestChannel(
alicePrivKey, bobPrivKey, chanAmt, chanAmt,
chanReserve, chanReserve, chanID,
)
@ -1628,7 +1630,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
pCache := newMockPreimageCache()
aliceDb := aliceChannel.State().Db
aliceDb := aliceLc.channel.State().Db
aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb)
if err != nil {
return nil, nil, nil, nil, nil, nil, err
@ -1668,7 +1670,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
}
const startingHeight = 100
aliceLink := NewChannelLink(aliceCfg, aliceChannel)
aliceLink := NewChannelLink(aliceCfg, aliceLc.channel)
start := func() error {
return aliceSwitch.AddLink(aliceLink)
}
@ -1687,7 +1689,8 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
defer fCleanUp()
}
return aliceLink, bobChannel, bticker.Force, start, cleanUp, restore, nil
return aliceLink, bobLc.channel, bticker.Force, start, cleanUp,
aliceLc.restore, nil
}
func assertLinkBandwidth(t *testing.T, link ChannelLink,
@ -2546,7 +2549,9 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {
t.Fatalf("unable to start test harness: %v", err)
}
alice := newPersistentLinkHarness(t, aliceLink, batchTicker, restore)
alice := newPersistentLinkHarness(
t, aliceLink, batchTicker, restore,
)
// Compute the static fees that will be used to determine the
// correctness of Alice's bandwidth when forwarding HTLCs.
@ -2818,7 +2823,9 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {
t.Fatalf("unable to start test harness: %v", err)
}
alice := newPersistentLinkHarness(t, aliceLink, batchTicker, restore)
alice := newPersistentLinkHarness(
t, aliceLink, batchTicker, restore,
)
// We'll put Alice into hodl.Commit mode, such that the circuits for any
// outgoing ADDs are opened, but the changes are not committed in the
@ -3980,14 +3987,15 @@ type persistentLinkHarness struct {
batchTicker chan time.Time
msgs chan lnwire.Message
restoreChan chanRestoreFunc
restoreChan func() (*lnwallet.LightningChannel, error)
}
// newPersistentLinkHarness initializes a new persistentLinkHarness and derives
// the supporting references from the active link.
func newPersistentLinkHarness(t *testing.T, link ChannelLink,
batchTicker chan time.Time,
restore chanRestoreFunc) *persistentLinkHarness {
restore func() (*lnwallet.LightningChannel,
error)) *persistentLinkHarness {
coreLink := link.(*channelLink)
@ -4034,7 +4042,7 @@ func (h *persistentLinkHarness) restart(restartSwitch bool,
// state, we will restore the persisted state to ensure we always start
// the link in a consistent state.
var err error
h.channel, _, err = h.restoreChan()
h.channel, err = h.restoreChan()
if err != nil {
h.t.Fatalf("unable to restore channels: %v", err)
}
@ -5573,7 +5581,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) {
t.Parallel()
// Setup a alice-bob network.
aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels(
alice, bob, cleanUp, err := createTwoClusterChannels(
btcutil.SatoshiPerBitcoin*3,
btcutil.SatoshiPerBitcoin*5)
if err != nil {
@ -5581,7 +5589,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) {
}
defer cleanUp()
n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight)
n := newTwoHopNetwork(t, alice.channel, bob.channel, testStartingHeight)
if err := n.start(); err != nil {
t.Fatal(err)
}
@ -5638,7 +5646,7 @@ type hodlInvoiceTestCtx struct {
func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) {
// Setup a alice-bob network.
aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels(
alice, bob, cleanUp, err := createTwoClusterChannels(
btcutil.SatoshiPerBitcoin*3,
btcutil.SatoshiPerBitcoin*5,
)
@ -5646,7 +5654,7 @@ func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) {
t.Fatalf("unable to create channel: %v", err)
}
n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight)
n := newTwoHopNetwork(t, alice.channel, bob.channel, testStartingHeight)
if err := n.start(); err != nil {
t.Fatal(err)
}

@ -147,15 +147,19 @@ func generateRandomBytes(n int) ([]byte, error) {
return b, nil
}
type testLightningChannel struct {
channel *lnwallet.LightningChannel
restore func() (*lnwallet.LightningChannel, error)
}
// createTestChannel creates the channel and returns our and remote channels
// representations.
//
// TODO(roasbeef): need to factor out, similar func re-used in many parts of codebase
func createTestChannel(alicePrivKey, bobPrivKey []byte,
aliceAmount, bobAmount, aliceReserve, bobReserve btcutil.Amount,
chanID lnwire.ShortChannelID) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(),
func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel,
error), error) {
chanID lnwire.ShortChannelID) (*testLightningChannel,
*testLightningChannel, func(), error) {
aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), alicePrivKey)
bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), bobPrivKey)
@ -187,7 +191,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
var hash [sha256.Size]byte
randomSeed, err := generateRandomBytes(sha256.Size)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
copy(hash[:], randomSeed)
@ -236,23 +240,23 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
@ -260,25 +264,25 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
bobAmount, &aliceCfg, &bobCfg, aliceCommitPoint, bobCommitPoint,
*fundingTxIn)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
alicePath, err := ioutil.TempDir("", "alicedb")
dbAlice, err := channeldb.Open(alicePath)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
bobPath, err := ioutil.TempDir("", "bobdb")
dbBob, err := channeldb.Open(bobPath)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
commitFee := feePerKw.FeeForWeight(724)
@ -350,11 +354,11 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
if err := bobChannelState.SyncPending(aliceAddr, broadcastHeight); err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
cleanUpFunc := func() {
@ -372,7 +376,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
aliceSigner, aliceChannelState, alicePool,
)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
alicePool.Start()
@ -381,7 +385,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
bobSigner, bobChannelState, bobPool,
)
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
bobPool.Start()
@ -389,40 +393,38 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
// having Alice and Bob extend their revocation windows to each other.
aliceNextRevoke, err := channelAlice.NextRevocationKey()
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
if err := channelBob.InitNextRevocation(aliceNextRevoke); err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
bobNextRevoke, err := channelBob.NextRevocationKey()
if err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
if err := channelAlice.InitNextRevocation(bobNextRevoke); err != nil {
return nil, nil, nil, nil, err
return nil, nil, nil, err
}
restore := func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel,
error) {
restoreAlice := func() (*lnwallet.LightningChannel, error) {
aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub)
switch err {
case nil:
case bbolt.ErrDatabaseNotOpen:
dbAlice, err = channeldb.Open(dbAlice.Path())
if err != nil {
return nil, nil, errors.Errorf("unable to reopen alice "+
return nil, errors.Errorf("unable to reopen alice "+
"db: %v", err)
}
aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub)
if err != nil {
return nil, nil, errors.Errorf("unable to fetch alice "+
return nil, errors.Errorf("unable to fetch alice "+
"channel: %v", err)
}
default:
return nil, nil, errors.Errorf("unable to fetch alice channel: "+
return nil, errors.Errorf("unable to fetch alice channel: "+
"%v", err)
}
@ -435,34 +437,38 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
if aliceStoredChannel == nil {
return nil, nil, errors.New("unable to find stored alice channel")
return nil, errors.New("unable to find stored alice channel")
}
newAliceChannel, err := lnwallet.NewLightningChannel(
aliceSigner, aliceStoredChannel, alicePool,
)
if err != nil {
return nil, nil, errors.Errorf("unable to create new channel: %v",
return nil, errors.Errorf("unable to create new channel: %v",
err)
}
return newAliceChannel, nil
}
restoreBob := func() (*lnwallet.LightningChannel, error) {
bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub)
switch err {
case nil:
case bbolt.ErrDatabaseNotOpen:
dbBob, err = channeldb.Open(dbBob.Path())
if err != nil {
return nil, nil, errors.Errorf("unable to reopen bob "+
return nil, errors.Errorf("unable to reopen bob "+
"db: %v", err)
}
bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub)
if err != nil {
return nil, nil, errors.Errorf("unable to fetch bob "+
return nil, errors.Errorf("unable to fetch bob "+
"channel: %v", err)
}
default:
return nil, nil, errors.Errorf("unable to fetch bob channel: "+
return nil, errors.Errorf("unable to fetch bob channel: "+
"%v", err)
}
@ -475,20 +481,31 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
if bobStoredChannel == nil {
return nil, nil, errors.New("unable to find stored bob channel")
return nil, errors.New("unable to find stored bob channel")
}
newBobChannel, err := lnwallet.NewLightningChannel(
bobSigner, bobStoredChannel, bobPool,
)
if err != nil {
return nil, nil, errors.Errorf("unable to create new channel: %v",
return nil, errors.Errorf("unable to create new channel: %v",
err)
}
return newAliceChannel, newBobChannel, nil
return newBobChannel, nil
}
return channelAlice, channelBob, cleanUpFunc, restore, nil
testLightningChannelAlice := &testLightningChannel{
channel: channelAlice,
restore: restoreAlice,
}
testLightningChannelBob := &testLightningChannel{
channel: channelBob,
restore: restoreBob,
}
return testLightningChannelAlice, testLightningChannelBob, cleanUpFunc,
nil
}
// getChanID retrieves the channel point from an lnnwire message.
@ -825,7 +842,7 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
_, _, firstChanID, secondChanID := genIDs()
// Create lightning channels between Alice<->Bob and Bob<->Carol
aliceChannel, firstBobChannel, cleanAliceBob, restoreAliceBob, err :=
aliceChannel, firstBobChannel, cleanAliceBob, err :=
createTestChannel(alicePrivKey, bobPrivKey, aliceToBob,
aliceToBob, 0, 0, firstChanID)
if err != nil {
@ -833,7 +850,7 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
"alice<->bob channel: %v", err)
}
secondBobChannel, carolChannel, cleanBobCarol, restoreBobCarol, err :=
secondBobChannel, carolChannel, cleanBobCarol, err :=
createTestChannel(bobPrivKey, carolPrivKey, bobToCarol,
bobToCarol, 0, 0, secondChanID)
if err != nil {
@ -848,12 +865,23 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
}
restoreFromDb := func() (*clusterChannels, error) {
a2b, b2a, err := restoreAliceBob()
a2b, err := aliceChannel.restore()
if err != nil {
return nil, err
}
b2c, c2b, err := restoreBobCarol()
b2a, err := firstBobChannel.restore()
if err != nil {
return nil, err
}
b2c, err := secondBobChannel.restore()
if err != nil {
return nil, err
}
c2b, err := carolChannel.restore()
if err != nil {
return nil, err
}
@ -867,10 +895,10 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
}
return &clusterChannels{
aliceToBob: aliceChannel,
bobToAlice: firstBobChannel,
bobToCarol: secondBobChannel,
carolToBob: carolChannel,
aliceToBob: aliceChannel.channel,
bobToAlice: firstBobChannel.channel,
bobToCarol: secondBobChannel.channel,
carolToBob: carolChannel.channel,
}, cleanUp, restoreFromDb, nil
}
@ -969,13 +997,13 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
// createTwoClusterChannels creates lightning channels which are needed for
// a 2 hop network cluster to be initialized.
func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
*lnwallet.LightningChannel, *lnwallet.LightningChannel,
*testLightningChannel, *testLightningChannel,
func(), error) {
_, _, firstChanID, _ := genIDs()
// Create lightning channels between Alice<->Bob and Bob<->Carol
aliceChannel, firstBobChannel, cleanAliceBob, _, err :=
alice, bob, cleanAliceBob, err :=
createTestChannel(alicePrivKey, bobPrivKey, aliceToBob,
aliceToBob, 0, 0, firstChanID)
if err != nil {
@ -983,7 +1011,7 @@ func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
"alice<->bob channel: %v", err)
}
return aliceChannel, firstBobChannel, cleanAliceBob, nil
return alice, bob, cleanAliceBob, nil
}
// hopNetwork is the base struct for two and three hop networks