htlcswitch/test: use single channel restore function
This commit refactors test code around channel restoration in unit tests to make it easier to use.
This commit is contained in:
parent
3d17c2bcfe
commit
e5ead599cc
@ -180,7 +180,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup a alice-bob network.
|
||||
aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels(
|
||||
alice, bob, cleanUp, err := createTwoClusterChannels(
|
||||
btcutil.SatoshiPerBitcoin*3,
|
||||
btcutil.SatoshiPerBitcoin*5)
|
||||
if err != nil {
|
||||
@ -188,7 +188,9 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight)
|
||||
n := newTwoHopNetwork(
|
||||
t, alice.channel, bob.channel, testStartingHeight,
|
||||
)
|
||||
if err := n.start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -1592,7 +1594,7 @@ func (m *mockPeer) Address() net.Addr {
|
||||
|
||||
func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
|
||||
ChannelLink, *lnwallet.LightningChannel, chan time.Time, func() error,
|
||||
func(), chanRestoreFunc, error) {
|
||||
func(), func() (*lnwallet.LightningChannel, error), error) {
|
||||
|
||||
var chanIDBytes [8]byte
|
||||
if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil {
|
||||
@ -1602,7 +1604,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
|
||||
chanID := lnwire.NewShortChanIDFromInt(
|
||||
binary.BigEndian.Uint64(chanIDBytes[:]))
|
||||
|
||||
aliceChannel, bobChannel, fCleanUp, restore, err := createTestChannel(
|
||||
aliceLc, bobLc, fCleanUp, err := createTestChannel(
|
||||
alicePrivKey, bobPrivKey, chanAmt, chanAmt,
|
||||
chanReserve, chanReserve, chanID,
|
||||
)
|
||||
@ -1628,7 +1630,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
|
||||
|
||||
pCache := newMockPreimageCache()
|
||||
|
||||
aliceDb := aliceChannel.State().Db
|
||||
aliceDb := aliceLc.channel.State().Db
|
||||
aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
@ -1668,7 +1670,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
|
||||
}
|
||||
|
||||
const startingHeight = 100
|
||||
aliceLink := NewChannelLink(aliceCfg, aliceChannel)
|
||||
aliceLink := NewChannelLink(aliceCfg, aliceLc.channel)
|
||||
start := func() error {
|
||||
return aliceSwitch.AddLink(aliceLink)
|
||||
}
|
||||
@ -1687,7 +1689,8 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
|
||||
defer fCleanUp()
|
||||
}
|
||||
|
||||
return aliceLink, bobChannel, bticker.Force, start, cleanUp, restore, nil
|
||||
return aliceLink, bobLc.channel, bticker.Force, start, cleanUp,
|
||||
aliceLc.restore, nil
|
||||
}
|
||||
|
||||
func assertLinkBandwidth(t *testing.T, link ChannelLink,
|
||||
@ -2546,7 +2549,9 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {
|
||||
t.Fatalf("unable to start test harness: %v", err)
|
||||
}
|
||||
|
||||
alice := newPersistentLinkHarness(t, aliceLink, batchTicker, restore)
|
||||
alice := newPersistentLinkHarness(
|
||||
t, aliceLink, batchTicker, restore,
|
||||
)
|
||||
|
||||
// Compute the static fees that will be used to determine the
|
||||
// correctness of Alice's bandwidth when forwarding HTLCs.
|
||||
@ -2818,7 +2823,9 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {
|
||||
t.Fatalf("unable to start test harness: %v", err)
|
||||
}
|
||||
|
||||
alice := newPersistentLinkHarness(t, aliceLink, batchTicker, restore)
|
||||
alice := newPersistentLinkHarness(
|
||||
t, aliceLink, batchTicker, restore,
|
||||
)
|
||||
|
||||
// We'll put Alice into hodl.Commit mode, such that the circuits for any
|
||||
// outgoing ADDs are opened, but the changes are not committed in the
|
||||
@ -3980,14 +3987,15 @@ type persistentLinkHarness struct {
|
||||
batchTicker chan time.Time
|
||||
msgs chan lnwire.Message
|
||||
|
||||
restoreChan chanRestoreFunc
|
||||
restoreChan func() (*lnwallet.LightningChannel, error)
|
||||
}
|
||||
|
||||
// newPersistentLinkHarness initializes a new persistentLinkHarness and derives
|
||||
// the supporting references from the active link.
|
||||
func newPersistentLinkHarness(t *testing.T, link ChannelLink,
|
||||
batchTicker chan time.Time,
|
||||
restore chanRestoreFunc) *persistentLinkHarness {
|
||||
restore func() (*lnwallet.LightningChannel,
|
||||
error)) *persistentLinkHarness {
|
||||
|
||||
coreLink := link.(*channelLink)
|
||||
|
||||
@ -4034,7 +4042,7 @@ func (h *persistentLinkHarness) restart(restartSwitch bool,
|
||||
// state, we will restore the persisted state to ensure we always start
|
||||
// the link in a consistent state.
|
||||
var err error
|
||||
h.channel, _, err = h.restoreChan()
|
||||
h.channel, err = h.restoreChan()
|
||||
if err != nil {
|
||||
h.t.Fatalf("unable to restore channels: %v", err)
|
||||
}
|
||||
@ -5573,7 +5581,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Setup a alice-bob network.
|
||||
aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels(
|
||||
alice, bob, cleanUp, err := createTwoClusterChannels(
|
||||
btcutil.SatoshiPerBitcoin*3,
|
||||
btcutil.SatoshiPerBitcoin*5)
|
||||
if err != nil {
|
||||
@ -5581,7 +5589,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) {
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight)
|
||||
n := newTwoHopNetwork(t, alice.channel, bob.channel, testStartingHeight)
|
||||
if err := n.start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -5638,7 +5646,7 @@ type hodlInvoiceTestCtx struct {
|
||||
|
||||
func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) {
|
||||
// Setup a alice-bob network.
|
||||
aliceChannel, bobChannel, cleanUp, err := createTwoClusterChannels(
|
||||
alice, bob, cleanUp, err := createTwoClusterChannels(
|
||||
btcutil.SatoshiPerBitcoin*3,
|
||||
btcutil.SatoshiPerBitcoin*5,
|
||||
)
|
||||
@ -5646,7 +5654,7 @@ func newHodlInvoiceTestCtx(t *testing.T) (*hodlInvoiceTestCtx, error) {
|
||||
t.Fatalf("unable to create channel: %v", err)
|
||||
}
|
||||
|
||||
n := newTwoHopNetwork(t, aliceChannel, bobChannel, testStartingHeight)
|
||||
n := newTwoHopNetwork(t, alice.channel, bob.channel, testStartingHeight)
|
||||
if err := n.start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -147,15 +147,19 @@ func generateRandomBytes(n int) ([]byte, error) {
|
||||
return b, nil
|
||||
}
|
||||
|
||||
type testLightningChannel struct {
|
||||
channel *lnwallet.LightningChannel
|
||||
restore func() (*lnwallet.LightningChannel, error)
|
||||
}
|
||||
|
||||
// createTestChannel creates the channel and returns our and remote channels
|
||||
// representations.
|
||||
//
|
||||
// TODO(roasbeef): need to factor out, similar func re-used in many parts of codebase
|
||||
func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
aliceAmount, bobAmount, aliceReserve, bobReserve btcutil.Amount,
|
||||
chanID lnwire.ShortChannelID) (*lnwallet.LightningChannel, *lnwallet.LightningChannel, func(),
|
||||
func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel,
|
||||
error), error) {
|
||||
chanID lnwire.ShortChannelID) (*testLightningChannel,
|
||||
*testLightningChannel, func(), error) {
|
||||
|
||||
aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), alicePrivKey)
|
||||
bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), bobPrivKey)
|
||||
@ -187,7 +191,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
var hash [sha256.Size]byte
|
||||
randomSeed, err := generateRandomBytes(sha256.Size)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
copy(hash[:], randomSeed)
|
||||
|
||||
@ -236,23 +240,23 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
|
||||
bobRoot, err := chainhash.NewHash(bobKeyPriv.Serialize())
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
bobPreimageProducer := shachain.NewRevocationProducer(*bobRoot)
|
||||
bobFirstRevoke, err := bobPreimageProducer.AtIndex(0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
bobCommitPoint := input.ComputeCommitmentPoint(bobFirstRevoke[:])
|
||||
|
||||
aliceRoot, err := chainhash.NewHash(aliceKeyPriv.Serialize())
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
alicePreimageProducer := shachain.NewRevocationProducer(*aliceRoot)
|
||||
aliceFirstRevoke, err := alicePreimageProducer.AtIndex(0)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
aliceCommitPoint := input.ComputeCommitmentPoint(aliceFirstRevoke[:])
|
||||
|
||||
@ -260,25 +264,25 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
bobAmount, &aliceCfg, &bobCfg, aliceCommitPoint, bobCommitPoint,
|
||||
*fundingTxIn)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
alicePath, err := ioutil.TempDir("", "alicedb")
|
||||
dbAlice, err := channeldb.Open(alicePath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
bobPath, err := ioutil.TempDir("", "bobdb")
|
||||
dbBob, err := channeldb.Open(bobPath)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
|
||||
feePerKw, err := estimator.EstimateFeePerKW(1)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
commitFee := feePerKw.FeeForWeight(724)
|
||||
|
||||
@ -350,11 +354,11 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
}
|
||||
|
||||
if err := aliceChannelState.SyncPending(bobAddr, broadcastHeight); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if err := bobChannelState.SyncPending(aliceAddr, broadcastHeight); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
cleanUpFunc := func() {
|
||||
@ -372,7 +376,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
aliceSigner, aliceChannelState, alicePool,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
alicePool.Start()
|
||||
|
||||
@ -381,7 +385,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
bobSigner, bobChannelState, bobPool,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
bobPool.Start()
|
||||
|
||||
@ -389,40 +393,38 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
// having Alice and Bob extend their revocation windows to each other.
|
||||
aliceNextRevoke, err := channelAlice.NextRevocationKey()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if err := channelBob.InitNextRevocation(aliceNextRevoke); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
bobNextRevoke, err := channelBob.NextRevocationKey()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if err := channelAlice.InitNextRevocation(bobNextRevoke); err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
restore := func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel,
|
||||
error) {
|
||||
|
||||
restoreAlice := func() (*lnwallet.LightningChannel, error) {
|
||||
aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub)
|
||||
switch err {
|
||||
case nil:
|
||||
case bbolt.ErrDatabaseNotOpen:
|
||||
dbAlice, err = channeldb.Open(dbAlice.Path())
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("unable to reopen alice "+
|
||||
return nil, errors.Errorf("unable to reopen alice "+
|
||||
"db: %v", err)
|
||||
}
|
||||
|
||||
aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("unable to fetch alice "+
|
||||
return nil, errors.Errorf("unable to fetch alice "+
|
||||
"channel: %v", err)
|
||||
}
|
||||
default:
|
||||
return nil, nil, errors.Errorf("unable to fetch alice channel: "+
|
||||
return nil, errors.Errorf("unable to fetch alice channel: "+
|
||||
"%v", err)
|
||||
}
|
||||
|
||||
@ -435,34 +437,38 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
}
|
||||
|
||||
if aliceStoredChannel == nil {
|
||||
return nil, nil, errors.New("unable to find stored alice channel")
|
||||
return nil, errors.New("unable to find stored alice channel")
|
||||
}
|
||||
|
||||
newAliceChannel, err := lnwallet.NewLightningChannel(
|
||||
aliceSigner, aliceStoredChannel, alicePool,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("unable to create new channel: %v",
|
||||
return nil, errors.Errorf("unable to create new channel: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
return newAliceChannel, nil
|
||||
}
|
||||
|
||||
restoreBob := func() (*lnwallet.LightningChannel, error) {
|
||||
bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub)
|
||||
switch err {
|
||||
case nil:
|
||||
case bbolt.ErrDatabaseNotOpen:
|
||||
dbBob, err = channeldb.Open(dbBob.Path())
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("unable to reopen bob "+
|
||||
return nil, errors.Errorf("unable to reopen bob "+
|
||||
"db: %v", err)
|
||||
}
|
||||
|
||||
bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("unable to fetch bob "+
|
||||
return nil, errors.Errorf("unable to fetch bob "+
|
||||
"channel: %v", err)
|
||||
}
|
||||
default:
|
||||
return nil, nil, errors.Errorf("unable to fetch bob channel: "+
|
||||
return nil, errors.Errorf("unable to fetch bob channel: "+
|
||||
"%v", err)
|
||||
}
|
||||
|
||||
@ -475,20 +481,31 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
||||
}
|
||||
|
||||
if bobStoredChannel == nil {
|
||||
return nil, nil, errors.New("unable to find stored bob channel")
|
||||
return nil, errors.New("unable to find stored bob channel")
|
||||
}
|
||||
|
||||
newBobChannel, err := lnwallet.NewLightningChannel(
|
||||
bobSigner, bobStoredChannel, bobPool,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("unable to create new channel: %v",
|
||||
return nil, errors.Errorf("unable to create new channel: %v",
|
||||
err)
|
||||
}
|
||||
return newAliceChannel, newBobChannel, nil
|
||||
return newBobChannel, nil
|
||||
}
|
||||
|
||||
return channelAlice, channelBob, cleanUpFunc, restore, nil
|
||||
testLightningChannelAlice := &testLightningChannel{
|
||||
channel: channelAlice,
|
||||
restore: restoreAlice,
|
||||
}
|
||||
|
||||
testLightningChannelBob := &testLightningChannel{
|
||||
channel: channelBob,
|
||||
restore: restoreBob,
|
||||
}
|
||||
|
||||
return testLightningChannelAlice, testLightningChannelBob, cleanUpFunc,
|
||||
nil
|
||||
}
|
||||
|
||||
// getChanID retrieves the channel point from an lnnwire message.
|
||||
@ -825,7 +842,7 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
|
||||
_, _, firstChanID, secondChanID := genIDs()
|
||||
|
||||
// Create lightning channels between Alice<->Bob and Bob<->Carol
|
||||
aliceChannel, firstBobChannel, cleanAliceBob, restoreAliceBob, err :=
|
||||
aliceChannel, firstBobChannel, cleanAliceBob, err :=
|
||||
createTestChannel(alicePrivKey, bobPrivKey, aliceToBob,
|
||||
aliceToBob, 0, 0, firstChanID)
|
||||
if err != nil {
|
||||
@ -833,7 +850,7 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
|
||||
"alice<->bob channel: %v", err)
|
||||
}
|
||||
|
||||
secondBobChannel, carolChannel, cleanBobCarol, restoreBobCarol, err :=
|
||||
secondBobChannel, carolChannel, cleanBobCarol, err :=
|
||||
createTestChannel(bobPrivKey, carolPrivKey, bobToCarol,
|
||||
bobToCarol, 0, 0, secondChanID)
|
||||
if err != nil {
|
||||
@ -848,12 +865,23 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
|
||||
}
|
||||
|
||||
restoreFromDb := func() (*clusterChannels, error) {
|
||||
a2b, b2a, err := restoreAliceBob()
|
||||
|
||||
a2b, err := aliceChannel.restore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b2c, c2b, err := restoreBobCarol()
|
||||
b2a, err := firstBobChannel.restore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b2c, err := secondBobChannel.restore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c2b, err := carolChannel.restore()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -867,10 +895,10 @@ func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
|
||||
}
|
||||
|
||||
return &clusterChannels{
|
||||
aliceToBob: aliceChannel,
|
||||
bobToAlice: firstBobChannel,
|
||||
bobToCarol: secondBobChannel,
|
||||
carolToBob: carolChannel,
|
||||
aliceToBob: aliceChannel.channel,
|
||||
bobToAlice: firstBobChannel.channel,
|
||||
bobToCarol: secondBobChannel.channel,
|
||||
carolToBob: carolChannel.channel,
|
||||
}, cleanUp, restoreFromDb, nil
|
||||
}
|
||||
|
||||
@ -969,13 +997,13 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
|
||||
// createTwoClusterChannels creates lightning channels which are needed for
|
||||
// a 2 hop network cluster to be initialized.
|
||||
func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
|
||||
*lnwallet.LightningChannel, *lnwallet.LightningChannel,
|
||||
*testLightningChannel, *testLightningChannel,
|
||||
func(), error) {
|
||||
|
||||
_, _, firstChanID, _ := genIDs()
|
||||
|
||||
// Create lightning channels between Alice<->Bob and Bob<->Carol
|
||||
aliceChannel, firstBobChannel, cleanAliceBob, _, err :=
|
||||
alice, bob, cleanAliceBob, err :=
|
||||
createTestChannel(alicePrivKey, bobPrivKey, aliceToBob,
|
||||
aliceToBob, 0, 0, firstChanID)
|
||||
if err != nil {
|
||||
@ -983,7 +1011,7 @@ func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
|
||||
"alice<->bob channel: %v", err)
|
||||
}
|
||||
|
||||
return aliceChannel, firstBobChannel, cleanAliceBob, nil
|
||||
return alice, bob, cleanAliceBob, nil
|
||||
}
|
||||
|
||||
// hopNetwork is the base struct for two and three hop networks
|
||||
|
Loading…
Reference in New Issue
Block a user