Merge pull request #1229 from wpaulino/randomize-link-fee-updates

htlcswitch: randomize link fee updates
This commit is contained in:
Olaoluwa Osuntokun 2018-06-13 19:11:48 -07:00 committed by GitHub
commit 39f2739d84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 236 additions and 261 deletions

@ -2,16 +2,15 @@ package htlcswitch
import ( import (
"bytes" "bytes"
"crypto/sha256"
"fmt" "fmt"
prand "math/rand"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"crypto/sha256"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hodl"
@ -21,6 +20,10 @@ import (
"github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcd/chaincfg/chainhash"
) )
func init() {
prand.Seed(time.Now().UnixNano())
}
const ( const (
// expiryGraceDelta is a grace period that the timeout of incoming // expiryGraceDelta is a grace period that the timeout of incoming
// HTLC's that pay directly to us (i.e we're the "exit node") must up // HTLC's that pay directly to us (i.e we're the "exit node") must up
@ -36,6 +39,12 @@ const (
// for a new fee update. We'll use this as a fee floor when proposing // for a new fee update. We'll use this as a fee floor when proposing
// and accepting updates. // and accepting updates.
minCommitFeePerKw = 253 minCommitFeePerKw = 253
// DefaultMinLinkFeeUpdateTimeout and DefaultMaxLinkFeeUpdateTimeout
// represent the default timeout bounds in which a link should propose
// to update its commitment fee rate.
DefaultMinLinkFeeUpdateTimeout = 10 * time.Minute
DefaultMaxLinkFeeUpdateTimeout = 60 * time.Minute
) )
// ForwardingPolicy describes the set of constraints that a given ChannelLink // ForwardingPolicy describes the set of constraints that a given ChannelLink
@ -202,13 +211,6 @@ type ChannelLinkConfig struct {
// transaction to ensure timely confirmation. // transaction to ensure timely confirmation.
FeeEstimator lnwallet.FeeEstimator FeeEstimator lnwallet.FeeEstimator
// BlockEpochs is an active block epoch event stream backed by an
// active ChainNotifier instance. The ChannelLink will use new block
// notifications sent over this channel to decide when a _new_ HTLC is
// too close to expiry, and also when any active HTLC's have expired
// (or are close to expiry).
BlockEpochs *chainntnfs.BlockEpochEvent
// DebugHTLC should be turned on if you want all HTLCs sent to a node // DebugHTLC should be turned on if you want all HTLCs sent to a node
// with the debug htlc R-Hash are immediately settled in the next // with the debug htlc R-Hash are immediately settled in the next
// available state transition. // available state transition.
@ -248,6 +250,12 @@ type ChannelLinkConfig struct {
// in testing, it is here to ensure the sphinx replay detection on the // in testing, it is here to ensure the sphinx replay detection on the
// receiving node is persistent. // receiving node is persistent.
UnsafeReplay bool UnsafeReplay bool
// MinFeeUpdateTimeout and MaxFeeUpdateTimeout represent the timeout
// interval bounds in which a link will propose to update its commitment
// fee rate. A random timeout will be selected between these values.
MinFeeUpdateTimeout time.Duration
MaxFeeUpdateTimeout time.Duration
} }
// channelLink is the service which drives a channel's commitment update // channelLink is the service which drives a channel's commitment update
@ -273,10 +281,6 @@ type channelLink struct {
// method in state machine. // method in state machine.
batchCounter uint32 batchCounter uint32
// bestHeight is the best known height of the main chain. The link will
// use this information to govern decisions based on HTLC timeouts.
bestHeight uint32
// keystoneBatch represents a volatile list of keystones that must be // keystoneBatch represents a volatile list of keystones that must be
// written before attempting to sign the next commitment txn. These // written before attempting to sign the next commitment txn. These
// represent all the HTLC's forwarded to the link from the switch. Once // represent all the HTLC's forwarded to the link from the switch. Once
@ -342,6 +346,10 @@ type channelLink struct {
logCommitTimer *time.Timer logCommitTimer *time.Timer
logCommitTick <-chan time.Time logCommitTick <-chan time.Time
// updateFeeTimer is the timer responsible for updating the link's
// commitment fee every time it fires.
updateFeeTimer *time.Timer
sync.RWMutex sync.RWMutex
wg sync.WaitGroup wg sync.WaitGroup
@ -350,8 +358,8 @@ type channelLink struct {
// NewChannelLink creates a new instance of a ChannelLink given a configuration // NewChannelLink creates a new instance of a ChannelLink given a configuration
// and active channel that will be used to verify/apply updates to. // and active channel that will be used to verify/apply updates to.
func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel, func NewChannelLink(cfg ChannelLinkConfig,
currentHeight uint32) ChannelLink { channel *lnwallet.LightningChannel) ChannelLink {
return &channelLink{ return &channelLink{
cfg: cfg, cfg: cfg,
@ -360,7 +368,6 @@ func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel,
// TODO(roasbeef): just do reserve here? // TODO(roasbeef): just do reserve here?
logCommitTimer: time.NewTimer(300 * time.Millisecond), logCommitTimer: time.NewTimer(300 * time.Millisecond),
overflowQueue: newPacketQueue(lnwallet.MaxHTLCNumber / 2), overflowQueue: newPacketQueue(lnwallet.MaxHTLCNumber / 2),
bestHeight: currentHeight,
htlcUpdates: make(chan []channeldb.HTLC), htlcUpdates: make(chan []channeldb.HTLC),
quit: make(chan struct{}), quit: make(chan struct{}),
} }
@ -427,6 +434,8 @@ func (l *channelLink) Start() error {
} }
} }
l.updateFeeTimer = time.NewTimer(l.randomFeeUpdateTimeout())
l.wg.Add(1) l.wg.Add(1)
go l.htlcManager() go l.htlcManager()
@ -449,8 +458,8 @@ func (l *channelLink) Stop() {
l.cfg.ChainEvents.Cancel() l.cfg.ChainEvents.Cancel()
} }
l.updateFeeTimer.Stop()
l.channel.Stop() l.channel.Stop()
l.overflowQueue.Stop() l.overflowQueue.Stop()
close(l.quit) close(l.quit)
@ -781,7 +790,6 @@ func (l *channelLink) fwdPkgGarbager() {
func (l *channelLink) htlcManager() { func (l *channelLink) htlcManager() {
defer func() { defer func() {
l.wg.Done() l.wg.Done()
l.cfg.BlockEpochs.Cancel()
log.Infof("ChannelLink(%v) has exited", l) log.Infof("ChannelLink(%v) has exited", l)
}() }()
@ -835,7 +843,6 @@ func (l *channelLink) htlcManager() {
out: out:
for { for {
// We must always check if we failed at some point processing // We must always check if we failed at some point processing
// the last update before processing the next. // the last update before processing the next.
if l.failed { if l.failed {
@ -844,16 +851,10 @@ out:
} }
select { select {
// Our update fee timer has fired, so we'll check the network
// A new block has arrived, we'll check the network fee to see // fee to see if we should adjust our commitment fee.
// if we should adjust our commitment fee, and also update our case <-l.updateFeeTimer.C:
// track of the best current height. l.updateFeeTimer.Reset(l.randomFeeUpdateTimeout())
case blockEpoch, ok := <-l.cfg.BlockEpochs.Epochs:
if !ok {
break out
}
l.bestHeight = uint32(blockEpoch.Height)
// If we're not the initiator of the channel, don't we // If we're not the initiator of the channel, don't we
// don't control the fees, so we can ignore this. // don't control the fees, so we can ignore this.
@ -983,6 +984,20 @@ out:
} }
} }
// randomFeeUpdateTimeout returns a random timeout between the bounds defined
// within the link's configuration that will be used to determine when the link
// should propose an update to its commitment fee rate.
func (l *channelLink) randomFeeUpdateTimeout() time.Duration {
lower := int64(l.cfg.MinFeeUpdateTimeout)
upper := int64(l.cfg.MaxFeeUpdateTimeout)
rand := prand.Int63n(upper)
if rand < lower {
rand = lower
}
return time.Duration(rand)
}
// handleDownStreamPkt processes an HTLC packet sent from the downstream HTLC // handleDownStreamPkt processes an HTLC packet sent from the downstream HTLC
// Switch. Possible messages sent by the switch include requests to forward new // Switch. Possible messages sent by the switch include requests to forward new
// HTLCs, timeout previously cleared HTLCs, and finally to settle currently // HTLCs, timeout previously cleared HTLCs, and finally to settle currently
@ -2065,7 +2080,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
continue continue
} }
heightNow := l.bestHeight heightNow := l.cfg.Switch.BestHeight()
fwdInfo := chanIterator.ForwardingInstructions() fwdInfo := chanIterator.ForwardingInstructions()
switch fwdInfo.NextHop { switch fwdInfo.NextHop {

@ -6,6 +6,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"math"
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
@ -13,12 +14,9 @@ import (
"testing" "testing"
"time" "time"
"math"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hodl"
@ -1057,7 +1055,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) {
htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight, htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight,
n.firstBobChannelLink, n.carolChannelLink) n.firstBobChannelLink, n.carolChannelLink)
daveServer, err := newMockServer(t, "dave", nil) daveServer, err := newMockServer(t, "dave", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init dave's server: %v", err) t.Fatalf("unable to init dave's server: %v", err)
} }
@ -1443,11 +1441,6 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
} }
var ( var (
globalEpoch = &chainntnfs.BlockEpochEvent{
Epochs: make(chan *chainntnfs.BlockEpoch),
Cancel: func() {
},
}
invoiceRegistry = newMockRegistry() invoiceRegistry = newMockRegistry()
decoder = newMockIteratorDecoder() decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator() obfuscator = NewMockObfuscator()
@ -1468,7 +1461,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
} }
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db
aliceSwitch, err := initSwitchWithDB(aliceDb) aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb)
if err != nil { if err != nil {
return nil, nil, nil, nil, nil, nil, err return nil, nil, nil, nil, nil, nil, err
} }
@ -1495,16 +1488,17 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
}, },
Registry: invoiceRegistry, Registry: invoiceRegistry,
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
BlockEpochs: globalEpoch,
BatchTicker: ticker, BatchTicker: ticker,
FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)), FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)),
// Make the BatchSize large enough to not // Make the BatchSize and Min/MaxFeeUpdateTimeout large enough
// trigger commit update automatically during tests. // to not trigger commit updates automatically during tests.
BatchSize: 10000, BatchSize: 10000,
MinFeeUpdateTimeout: 30 * time.Minute,
MaxFeeUpdateTimeout: 30 * time.Minute,
} }
const startingHeight = 100 const startingHeight = 100
aliceLink := NewChannelLink(aliceCfg, aliceChannel, startingHeight) aliceLink := NewChannelLink(aliceCfg, aliceChannel)
start := func() error { start := func() error {
return aliceSwitch.AddLink(aliceLink) return aliceSwitch.AddLink(aliceLink)
} }
@ -3451,22 +3445,9 @@ func TestChannelLinkUpdateCommitFee(t *testing.T) {
defer n.stop() defer n.stop()
defer n.feeEstimator.Stop() defer n.feeEstimator.Stop()
// First, we'll start off all channels at "height" 9000 by sending a // For the sake of this test, we'll reset the timer to fire in a second
// new epoch to all the clients. // so that Alice's link queries for a new network fee.
select { n.aliceChannelLink.updateFeeTimer.Reset(time.Millisecond)
case n.aliceBlockEpoch <- &chainntnfs.BlockEpoch{
Height: 9000,
}:
case <-time.After(time.Second * 5):
t.Fatalf("link didn't read block epoch")
}
select {
case n.bobFirstBlockEpoch <- &chainntnfs.BlockEpoch{
Height: 9000,
}:
case <-time.After(time.Second * 5):
t.Fatalf("link didn't read block epoch")
}
startingFeeRate := channels.aliceToBob.CommitFeeRate() startingFeeRate := channels.aliceToBob.CommitFeeRate()
@ -3480,20 +3461,15 @@ func TestChannelLinkUpdateCommitFee(t *testing.T) {
select { select {
case n.feeEstimator.byteFeeIn <- startingFeeRateSatPerVByte: case n.feeEstimator.byteFeeIn <- startingFeeRateSatPerVByte:
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Fatalf("alice didn't query for the new " + t.Fatalf("alice didn't query for the new network fee")
"network fee")
} }
time.Sleep(time.Millisecond * 500) time.Sleep(time.Second)
// The fee rate on the alice <-> bob channel should still be the same // The fee rate on the alice <-> bob channel should still be the same
// on both sides. // on both sides.
aliceFeeRate := channels.aliceToBob.CommitFeeRate() aliceFeeRate := channels.aliceToBob.CommitFeeRate()
bobFeeRate := channels.bobToAlice.CommitFeeRate() bobFeeRate := channels.bobToAlice.CommitFeeRate()
if aliceFeeRate != bobFeeRate {
t.Fatalf("fee rates don't match: expected %v got %v",
aliceFeeRate, bobFeeRate)
}
if aliceFeeRate != startingFeeRate { if aliceFeeRate != startingFeeRate {
t.Fatalf("alice's fee rate shouldn't have changed: "+ t.Fatalf("alice's fee rate shouldn't have changed: "+
"expected %v, got %v", aliceFeeRate, startingFeeRate) "expected %v, got %v", aliceFeeRate, startingFeeRate)
@ -3503,22 +3479,9 @@ func TestChannelLinkUpdateCommitFee(t *testing.T) {
"expected %v, got %v", bobFeeRate, startingFeeRate) "expected %v, got %v", bobFeeRate, startingFeeRate)
} }
// Now we'll send a new block update to all end points, with a new // We'll reset the timer once again to ensure Alice's link queries for a
// height THAT'S OVER 9000!!! // new network fee.
select { n.aliceChannelLink.updateFeeTimer.Reset(time.Millisecond)
case n.aliceBlockEpoch <- &chainntnfs.BlockEpoch{
Height: 9001,
}:
case <-time.After(time.Second * 5):
t.Fatalf("link didn't read block epoch")
}
select {
case n.bobFirstBlockEpoch <- &chainntnfs.BlockEpoch{
Height: 9001,
}:
case <-time.After(time.Second * 5):
t.Fatalf("link didn't read block epoch")
}
// Next, we'll set up a deliver a fee rate that's triple the current // Next, we'll set up a deliver a fee rate that's triple the current
// fee rate. This should cause the Alice (the initiator) to trigger a // fee rate. This should cause the Alice (the initiator) to trigger a
@ -3527,11 +3490,10 @@ func TestChannelLinkUpdateCommitFee(t *testing.T) {
select { select {
case n.feeEstimator.byteFeeIn <- startingFeeRateSatPerVByte * 3: case n.feeEstimator.byteFeeIn <- startingFeeRateSatPerVByte * 3:
case <-time.After(time.Second * 5): case <-time.After(time.Second * 5):
t.Fatalf("alice didn't query for the new " + t.Fatalf("alice didn't query for the new network fee")
"network fee")
} }
time.Sleep(time.Second * 2) time.Sleep(time.Second)
// At this point, Alice should've triggered a new fee update that // At this point, Alice should've triggered a new fee update that
// increased the fee rate to match the new rate. // increased the fee rate to match the new rate.
@ -3545,10 +3507,6 @@ func TestChannelLinkUpdateCommitFee(t *testing.T) {
t.Fatalf("bob's fee rate didn't change: expected %v, got %v", t.Fatalf("bob's fee rate didn't change: expected %v, got %v",
newFeeRate, aliceFeeRate) newFeeRate, aliceFeeRate)
} }
if aliceFeeRate != bobFeeRate {
t.Fatalf("fee rates don't match: expected %v got %v",
aliceFeeRate, bobFeeRate)
}
} }
// TestChannelLinkAcceptDuplicatePayment tests that if a link receives an // TestChannelLinkAcceptDuplicatePayment tests that if a link receives an
@ -3859,11 +3817,6 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) { hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) {
var ( var (
globalEpoch = &chainntnfs.BlockEpochEvent{
Epochs: make(chan *chainntnfs.BlockEpoch),
Cancel: func() {
},
}
invoiceRegistry = newMockRegistry() invoiceRegistry = newMockRegistry()
decoder = newMockIteratorDecoder() decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator() obfuscator = NewMockObfuscator()
@ -3888,7 +3841,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
if aliceSwitch == nil { if aliceSwitch == nil {
var err error var err error
aliceSwitch, err = initSwitchWithDB(aliceDb) aliceSwitch, err = initSwitchWithDB(testStartingHeight, aliceDb)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -3914,19 +3867,20 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
}, },
Registry: invoiceRegistry, Registry: invoiceRegistry,
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
BlockEpochs: globalEpoch,
BatchTicker: ticker, BatchTicker: ticker,
FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)), FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)),
// Make the BatchSize large enough to not // Make the BatchSize and Min/MaxFeeUpdateTimeout large enough
// trigger commit update automatically during tests. // to not trigger commit updates automatically during tests.
BatchSize: 10000, BatchSize: 10000,
MinFeeUpdateTimeout: 30 * time.Minute,
MaxFeeUpdateTimeout: 30 * time.Minute,
// Set any hodl flags requested for the new link. // Set any hodl flags requested for the new link.
HodlMask: hodl.MaskFromFlags(hodlFlags...), HodlMask: hodl.MaskFromFlags(hodlFlags...),
DebugHTLC: len(hodlFlags) > 0, DebugHTLC: len(hodlFlags) > 0,
} }
const startingHeight = 100 const startingHeight = 100
aliceLink := NewChannelLink(aliceCfg, aliceChannel, startingHeight) aliceLink := NewChannelLink(aliceCfg, aliceChannel)
if err := aliceSwitch.AddLink(aliceLink); err != nil { if err := aliceSwitch.AddLink(aliceLink); err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }

@ -1,19 +1,17 @@
package htlcswitch package htlcswitch
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"io"
"sync/atomic"
"bytes"
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lightning-onion"
@ -122,7 +120,7 @@ type mockServer struct {
var _ lnpeer.Peer = (*mockServer)(nil) var _ lnpeer.Peer = (*mockServer)(nil)
func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) {
if db == nil { if db == nil {
tempPath, err := ioutil.TempDir("", "switchdb") tempPath, err := ioutil.TempDir("", "switchdb")
if err != nil { if err != nil {
@ -135,7 +133,7 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
} }
} }
return New(Config{ cfg := Config{
DB: db, DB: db,
SwitchPackager: channeldb.NewSwitchPackager(), SwitchPackager: channeldb.NewSwitchPackager(),
FwdingLog: &mockForwardingLog{ FwdingLog: &mockForwardingLog{
@ -144,15 +142,20 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, nil return nil, nil
}, },
}) Notifier: &mockNotifier{},
}
return New(cfg, startingHeight)
} }
func newMockServer(t testing.TB, name string, db *channeldb.DB) (*mockServer, error) { func newMockServer(t testing.TB, name string, startingHeight uint32,
db *channeldb.DB) (*mockServer, error) {
var id [33]byte var id [33]byte
h := sha256.Sum256([]byte(name)) h := sha256.Sum256([]byte(name))
copy(id[:], h[:]) copy(id[:], h[:])
htlcSwitch, err := initSwitchWithDB(db) htlcSwitch, err := initSwitchWithDB(startingHeight, db)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -2,23 +2,22 @@ package htlcswitch
import ( import (
"bytes" "bytes"
"crypto/sha256"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"crypto/sha256"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcd/btcec"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
"github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil" "github.com/roasbeef/btcutil"
) )
@ -142,6 +141,10 @@ type Config struct {
// provide payment senders our latest policy when sending encrypted // provide payment senders our latest policy when sending encrypted
// error messages. // error messages.
FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error)
// Notifier is an instance of a chain notifier that we'll use to signal
// the switch when a new block has arrived.
Notifier chainntnfs.ChainNotifier
} }
// Switch is the central messaging bus for all incoming/outgoing HTLCs. // Switch is the central messaging bus for all incoming/outgoing HTLCs.
@ -155,6 +158,12 @@ type Config struct {
type Switch struct { type Switch struct {
started int32 // To be used atomically. started int32 // To be used atomically.
shutdown int32 // To be used atomically. shutdown int32 // To be used atomically.
// bestHeight is the best known height of the main chain. The links will
// be used this information to govern decisions based on HTLC timeouts.
// This will be retrieved by the registered links atomically.
bestHeight uint32
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
@ -229,10 +238,15 @@ type Switch struct {
// to the forwarding log. // to the forwarding log.
fwdEventMtx sync.Mutex fwdEventMtx sync.Mutex
pendingFwdingEvents []channeldb.ForwardingEvent pendingFwdingEvents []channeldb.ForwardingEvent
// blockEpochStream is an active block epoch event stream backed by an
// active ChainNotifier instance. This will be used to retrieve the
// lastest height of the chain.
blockEpochStream *chainntnfs.BlockEpochEvent
} }
// New creates the new instance of htlc switch. // New creates the new instance of htlc switch.
func New(cfg Config) (*Switch, error) { func New(cfg Config, currentHeight uint32) (*Switch, error) {
circuitMap, err := NewCircuitMap(&CircuitMapConfig{ circuitMap, err := NewCircuitMap(&CircuitMapConfig{
DB: cfg.DB, DB: cfg.DB,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter, ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
@ -247,6 +261,7 @@ func New(cfg Config) (*Switch, error) {
} }
return &Switch{ return &Switch{
bestHeight: currentHeight,
cfg: &cfg, cfg: &cfg,
circuits: circuitMap, circuits: circuitMap,
paymentSequencer: sequencer, paymentSequencer: sequencer,
@ -1339,8 +1354,10 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, closeType ChannelCloseType,
func (s *Switch) htlcForwarder() { func (s *Switch) htlcForwarder() {
defer s.wg.Done() defer s.wg.Done()
// Remove all links once we've been signalled for shutdown.
defer func() { defer func() {
s.blockEpochStream.Cancel()
// Remove all links once we've been signalled for shutdown.
s.indexMtx.Lock() s.indexMtx.Lock()
for _, link := range s.linkIndex { for _, link := range s.linkIndex {
if err := s.removeLink(link.ChanID()); err != nil { if err := s.removeLink(link.ChanID()); err != nil {
@ -1378,8 +1395,15 @@ func (s *Switch) htlcForwarder() {
fwdEventTicker := time.NewTicker(15 * time.Second) fwdEventTicker := time.NewTicker(15 * time.Second)
defer fwdEventTicker.Stop() defer fwdEventTicker.Stop()
out:
for { for {
select { select {
case blockEpoch, ok := <-s.blockEpochStream.Epochs:
if !ok {
break out
}
atomic.StoreUint32(&s.bestHeight, uint32(blockEpoch.Height))
// A local close request has arrived, we'll forward this to the // A local close request has arrived, we'll forward this to the
// relevant link (if it exists) so the channel can be // relevant link (if it exists) so the channel can be
// cooperatively closed (if possible). // cooperatively closed (if possible).
@ -1549,6 +1573,12 @@ func (s *Switch) Start() error {
log.Infof("Starting HTLC Switch") log.Infof("Starting HTLC Switch")
blockEpochStream, err := s.cfg.Notifier.RegisterBlockEpochNtfn()
if err != nil {
return err
}
s.blockEpochStream = blockEpochStream
s.wg.Add(1) s.wg.Add(1)
go s.htlcForwarder() go s.htlcForwarder()
@ -2033,3 +2063,8 @@ func (s *Switch) FlushForwardingEvents() error {
// forwarding log. // forwarding log.
return s.cfg.FwdingLog.AddForwardingEvents(events) return s.cfg.FwdingLog.AddForwardingEvents(events)
} }
// BestHeight returns the best height known to the switch.
func (s *Switch) BestHeight() uint32 {
return atomic.LoadUint32(&s.bestHeight)
}

@ -30,12 +30,12 @@ func genPreimage() ([32]byte, error) {
func TestSwitchSendPending(t *testing.T) { func TestSwitchSendPending(t *testing.T) {
t.Parallel() t.Parallel()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -125,16 +125,16 @@ func TestSwitchSendPending(t *testing.T) {
func TestSwitchForward(t *testing.T) { func TestSwitchForward(t *testing.T) {
t.Parallel() t.Parallel()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -230,11 +230,11 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs() chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
@ -249,7 +249,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
t.Fatalf("unable to open channeldb: %v", err) t.Fatalf("unable to open channeldb: %v", err)
} }
s, err := initSwitchWithDB(cdb) s, err := initSwitchWithDB(testStartingHeight, cdb)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -344,7 +344,7 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
t.Fatalf("unable to reopen channeldb: %v", err) t.Fatalf("unable to reopen channeldb: %v", err)
} }
s2, err := initSwitchWithDB(cdb2) s2, err := initSwitchWithDB(testStartingHeight, cdb2)
if err != nil { if err != nil {
t.Fatalf("unable reinit switch: %v", err) t.Fatalf("unable reinit switch: %v", err)
} }
@ -421,11 +421,11 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs() chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
@ -440,7 +440,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
t.Fatalf("unable to open channeldb: %v", err) t.Fatalf("unable to open channeldb: %v", err)
} }
s, err := initSwitchWithDB(cdb) s, err := initSwitchWithDB(testStartingHeight, cdb)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -535,7 +535,7 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
t.Fatalf("unable to reopen channeldb: %v", err) t.Fatalf("unable to reopen channeldb: %v", err)
} }
s2, err := initSwitchWithDB(cdb2) s2, err := initSwitchWithDB(testStartingHeight, cdb2)
if err != nil { if err != nil {
t.Fatalf("unable reinit switch: %v", err) t.Fatalf("unable reinit switch: %v", err)
} }
@ -615,11 +615,11 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs() chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
@ -634,7 +634,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
t.Fatalf("unable to open channeldb: %v", err) t.Fatalf("unable to open channeldb: %v", err)
} }
s, err := initSwitchWithDB(cdb) s, err := initSwitchWithDB(testStartingHeight, cdb)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -721,7 +721,7 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
t.Fatalf("unable to reopen channeldb: %v", err) t.Fatalf("unable to reopen channeldb: %v", err)
} }
s2, err := initSwitchWithDB(cdb2) s2, err := initSwitchWithDB(testStartingHeight, cdb2)
if err != nil { if err != nil {
t.Fatalf("unable reinit switch: %v", err) t.Fatalf("unable reinit switch: %v", err)
} }
@ -778,11 +778,11 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs() chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
@ -797,7 +797,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
t.Fatalf("unable to open channeldb: %v", err) t.Fatalf("unable to open channeldb: %v", err)
} }
s, err := initSwitchWithDB(cdb) s, err := initSwitchWithDB(testStartingHeight, cdb)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -879,7 +879,7 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
t.Fatalf("unable to reopen channeldb: %v", err) t.Fatalf("unable to reopen channeldb: %v", err)
} }
s2, err := initSwitchWithDB(cdb2) s2, err := initSwitchWithDB(testStartingHeight, cdb2)
if err != nil { if err != nil {
t.Fatalf("unable reinit switch: %v", err) t.Fatalf("unable reinit switch: %v", err)
} }
@ -936,11 +936,11 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs() chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
@ -955,7 +955,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
t.Fatalf("unable to open channeldb: %v", err) t.Fatalf("unable to open channeldb: %v", err)
} }
s, err := initSwitchWithDB(cdb) s, err := initSwitchWithDB(testStartingHeight, cdb)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -1036,7 +1036,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
t.Fatalf("unable to reopen channeldb: %v", err) t.Fatalf("unable to reopen channeldb: %v", err)
} }
s2, err := initSwitchWithDB(cdb2) s2, err := initSwitchWithDB(testStartingHeight, cdb2)
if err != nil { if err != nil {
t.Fatalf("unable reinit switch: %v", err) t.Fatalf("unable reinit switch: %v", err)
} }
@ -1129,7 +1129,7 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
t.Fatalf("unable to reopen channeldb: %v", err) t.Fatalf("unable to reopen channeldb: %v", err)
} }
s3, err := initSwitchWithDB(cdb3) s3, err := initSwitchWithDB(testStartingHeight, cdb3)
if err != nil { if err != nil {
t.Fatalf("unable reinit switch: %v", err) t.Fatalf("unable reinit switch: %v", err)
} }
@ -1167,16 +1167,16 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
var packet *htlcPacket var packet *htlcPacket
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -1237,12 +1237,12 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
// We'll create a single link for this test, marking it as being unable // We'll create a single link for this test, marking it as being unable
// to forward form the get go. // to forward form the get go.
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -1289,16 +1289,16 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
func TestSwitchCancel(t *testing.T) { func TestSwitchCancel(t *testing.T) {
t.Parallel() t.Parallel()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -1402,16 +1402,16 @@ func TestSwitchAddSamePayment(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs() chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobPeer, err := newMockServer(t, "bob", nil) bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -1561,12 +1561,12 @@ func TestSwitchAddSamePayment(t *testing.T) {
func TestSwitchSendPayment(t *testing.T) { func TestSwitchSendPayment(t *testing.T) {
t.Parallel() t.Parallel()
alicePeer, err := newMockServer(t, "alice", nil) alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
s, err := initSwitchWithDB(nil) s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil { if err != nil {
t.Fatalf("unable to init switch: %v", err) t.Fatalf("unable to init switch: %v", err)
} }
@ -1805,8 +1805,6 @@ func TestMultiHopPaymentForwardingEvents(t *testing.T) {
} }
} }
time.Sleep(time.Millisecond * 200)
// With all 10 payments sent. We'll now manually stop each of the // With all 10 payments sent. We'll now manually stop each of the
// switches so we can examine their end state. // switches so we can examine their end state.
n.stop() n.stop()

@ -17,7 +17,6 @@ import (
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
@ -569,22 +568,13 @@ func generateRoute(hops ...ForwardingInfo) ([lnwire.OnionPacketSize]byte, error)
type threeHopNetwork struct { type threeHopNetwork struct {
aliceServer *mockServer aliceServer *mockServer
aliceChannelLink *channelLink aliceChannelLink *channelLink
aliceBlockEpoch chan *chainntnfs.BlockEpoch
aliceTicker *time.Ticker
firstBobChannelLink *channelLink
bobFirstBlockEpoch chan *chainntnfs.BlockEpoch
firstBobTicker *time.Ticker
bobServer *mockServer bobServer *mockServer
firstBobChannelLink *channelLink
secondBobChannelLink *channelLink secondBobChannelLink *channelLink
bobSecondBlockEpoch chan *chainntnfs.BlockEpoch
secondBobTicker *time.Ticker
carolChannelLink *channelLink
carolServer *mockServer carolServer *mockServer
carolBlockEpoch chan *chainntnfs.BlockEpoch carolChannelLink *channelLink
carolTicker *time.Ticker
feeEstimator *mockFeeEstimator feeEstimator *mockFeeEstimator
@ -762,11 +752,6 @@ func (n *threeHopNetwork) stop() {
done <- struct{}{} done <- struct{}{}
}() }()
n.aliceTicker.Stop()
n.firstBobTicker.Stop()
n.secondBobTicker.Stop()
n.carolTicker.Stop()
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
<-done <-done
} }
@ -858,15 +843,15 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
carolDb := carolChannel.State().Db carolDb := carolChannel.State().Db
// Create three peers/servers. // Create three peers/servers.
aliceServer, err := newMockServer(t, "alice", aliceDb) aliceServer, err := newMockServer(t, "alice", startingHeight, aliceDb)
if err != nil { if err != nil {
t.Fatalf("unable to create alice server: %v", err) t.Fatalf("unable to create alice server: %v", err)
} }
bobServer, err := newMockServer(t, "bob", bobDb) bobServer, err := newMockServer(t, "bob", startingHeight, bobDb)
if err != nil { if err != nil {
t.Fatalf("unable to create bob server: %v", err) t.Fatalf("unable to create bob server: %v", err)
} }
carolServer, err := newMockServer(t, "carol", carolDb) carolServer, err := newMockServer(t, "carol", startingHeight, carolDb)
if err != nil { if err != nil {
t.Fatalf("unable to create carol server: %v", err) t.Fatalf("unable to create carol server: %v", err)
} }
@ -882,6 +867,12 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
quit: make(chan struct{}), quit: make(chan struct{}),
} }
const (
batchTimeout = 50 * time.Millisecond
fwdPkgTimeout = 5 * time.Second
feeUpdateTimeout = 30 * time.Minute
)
pCache := &mockPreimageCache{ pCache := &mockPreimageCache{
// hash -> preimage // hash -> preimage
preimageMap: make(map[[32]byte][]byte), preimageMap: make(map[[32]byte][]byte),
@ -894,13 +885,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
} }
obfuscator := NewMockObfuscator() obfuscator := NewMockObfuscator()
aliceEpochChan := make(chan *chainntnfs.BlockEpoch)
aliceEpoch := &chainntnfs.BlockEpochEvent{
Epochs: aliceEpochChan,
Cancel: func() {
},
}
aliceTicker := time.NewTicker(50 * time.Millisecond)
aliceChannelLink := NewChannelLink( aliceChannelLink := NewChannelLink(
ChannelLinkConfig{ ChannelLinkConfig{
Switch: aliceServer.htlcSwitch, Switch: aliceServer.htlcSwitch,
@ -915,7 +899,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
FetchLastChannelUpdate: mockGetChanUpdateMessage, FetchLastChannelUpdate: mockGetChanUpdateMessage,
Registry: aliceServer.registry, Registry: aliceServer.registry,
BlockEpochs: aliceEpoch,
FeeEstimator: feeEstimator, FeeEstimator: feeEstimator,
PreimageCache: pCache, PreimageCache: pCache,
UpdateContractSignals: func(*contractcourt.ContractSignals) error { UpdateContractSignals: func(*contractcourt.ContractSignals) error {
@ -923,12 +906,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true, SyncStates: true,
BatchTicker: &mockTicker{aliceTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10, BatchSize: 10,
BatchTicker: &mockTicker{time.NewTicker(batchTimeout).C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(fwdPkgTimeout).C},
MinFeeUpdateTimeout: feeUpdateTimeout,
MaxFeeUpdateTimeout: feeUpdateTimeout,
OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {},
}, },
aliceChannel, aliceChannel,
startingHeight,
) )
if err := aliceServer.htlcSwitch.AddLink(aliceChannelLink); err != nil { if err := aliceServer.htlcSwitch.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice channel link: %v", err) t.Fatalf("unable to add alice channel link: %v", err)
@ -943,13 +928,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
} }
}() }()
bobFirstEpochChan := make(chan *chainntnfs.BlockEpoch)
bobFirstEpoch := &chainntnfs.BlockEpochEvent{
Epochs: bobFirstEpochChan,
Cancel: func() {
},
}
firstBobTicker := time.NewTicker(50 * time.Millisecond)
firstBobChannelLink := NewChannelLink( firstBobChannelLink := NewChannelLink(
ChannelLinkConfig{ ChannelLinkConfig{
Switch: bobServer.htlcSwitch, Switch: bobServer.htlcSwitch,
@ -964,7 +942,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
FetchLastChannelUpdate: mockGetChanUpdateMessage, FetchLastChannelUpdate: mockGetChanUpdateMessage,
Registry: bobServer.registry, Registry: bobServer.registry,
BlockEpochs: bobFirstEpoch,
FeeEstimator: feeEstimator, FeeEstimator: feeEstimator,
PreimageCache: pCache, PreimageCache: pCache,
UpdateContractSignals: func(*contractcourt.ContractSignals) error { UpdateContractSignals: func(*contractcourt.ContractSignals) error {
@ -972,12 +949,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true, SyncStates: true,
BatchTicker: &mockTicker{firstBobTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10, BatchSize: 10,
BatchTicker: &mockTicker{time.NewTicker(batchTimeout).C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(fwdPkgTimeout).C},
MinFeeUpdateTimeout: feeUpdateTimeout,
MaxFeeUpdateTimeout: feeUpdateTimeout,
OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {},
}, },
firstBobChannel, firstBobChannel,
startingHeight,
) )
if err := bobServer.htlcSwitch.AddLink(firstBobChannelLink); err != nil { if err := bobServer.htlcSwitch.AddLink(firstBobChannelLink); err != nil {
t.Fatalf("unable to add first bob channel link: %v", err) t.Fatalf("unable to add first bob channel link: %v", err)
@ -992,13 +971,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
} }
}() }()
bobSecondEpochChan := make(chan *chainntnfs.BlockEpoch)
bobSecondEpoch := &chainntnfs.BlockEpochEvent{
Epochs: bobSecondEpochChan,
Cancel: func() {
},
}
secondBobTicker := time.NewTicker(50 * time.Millisecond)
secondBobChannelLink := NewChannelLink( secondBobChannelLink := NewChannelLink(
ChannelLinkConfig{ ChannelLinkConfig{
Switch: bobServer.htlcSwitch, Switch: bobServer.htlcSwitch,
@ -1013,7 +985,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
FetchLastChannelUpdate: mockGetChanUpdateMessage, FetchLastChannelUpdate: mockGetChanUpdateMessage,
Registry: bobServer.registry, Registry: bobServer.registry,
BlockEpochs: bobSecondEpoch,
FeeEstimator: feeEstimator, FeeEstimator: feeEstimator,
PreimageCache: pCache, PreimageCache: pCache,
UpdateContractSignals: func(*contractcourt.ContractSignals) error { UpdateContractSignals: func(*contractcourt.ContractSignals) error {
@ -1021,12 +992,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true, SyncStates: true,
BatchTicker: &mockTicker{secondBobTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10, BatchSize: 10,
BatchTicker: &mockTicker{time.NewTicker(batchTimeout).C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(fwdPkgTimeout).C},
MinFeeUpdateTimeout: feeUpdateTimeout,
MaxFeeUpdateTimeout: feeUpdateTimeout,
OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {},
}, },
secondBobChannel, secondBobChannel,
startingHeight,
) )
if err := bobServer.htlcSwitch.AddLink(secondBobChannelLink); err != nil { if err := bobServer.htlcSwitch.AddLink(secondBobChannelLink); err != nil {
t.Fatalf("unable to add second bob channel link: %v", err) t.Fatalf("unable to add second bob channel link: %v", err)
@ -1041,13 +1014,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
} }
}() }()
carolBlockEpoch := make(chan *chainntnfs.BlockEpoch)
carolEpoch := &chainntnfs.BlockEpochEvent{
Epochs: bobSecondEpochChan,
Cancel: func() {
},
}
carolTicker := time.NewTicker(50 * time.Millisecond)
carolChannelLink := NewChannelLink( carolChannelLink := NewChannelLink(
ChannelLinkConfig{ ChannelLinkConfig{
Switch: carolServer.htlcSwitch, Switch: carolServer.htlcSwitch,
@ -1062,7 +1028,6 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
FetchLastChannelUpdate: mockGetChanUpdateMessage, FetchLastChannelUpdate: mockGetChanUpdateMessage,
Registry: carolServer.registry, Registry: carolServer.registry,
BlockEpochs: carolEpoch,
FeeEstimator: feeEstimator, FeeEstimator: feeEstimator,
PreimageCache: pCache, PreimageCache: pCache,
UpdateContractSignals: func(*contractcourt.ContractSignals) error { UpdateContractSignals: func(*contractcourt.ContractSignals) error {
@ -1070,12 +1035,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
}, },
ChainEvents: &contractcourt.ChainEventSubscription{}, ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true, SyncStates: true,
BatchTicker: &mockTicker{carolTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10, BatchSize: 10,
BatchTicker: &mockTicker{time.NewTicker(batchTimeout).C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(fwdPkgTimeout).C},
MinFeeUpdateTimeout: feeUpdateTimeout,
MaxFeeUpdateTimeout: feeUpdateTimeout,
OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {},
}, },
carolChannel, carolChannel,
startingHeight,
) )
if err := carolServer.htlcSwitch.AddLink(carolChannelLink); err != nil { if err := carolServer.htlcSwitch.AddLink(carolChannelLink); err != nil {
t.Fatalf("unable to add carol channel link: %v", err) t.Fatalf("unable to add carol channel link: %v", err)
@ -1093,22 +1060,13 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
return &threeHopNetwork{ return &threeHopNetwork{
aliceServer: aliceServer, aliceServer: aliceServer,
aliceChannelLink: aliceChannelLink.(*channelLink), aliceChannelLink: aliceChannelLink.(*channelLink),
aliceBlockEpoch: aliceEpochChan,
aliceTicker: aliceTicker,
firstBobChannelLink: firstBobChannelLink.(*channelLink),
bobFirstBlockEpoch: bobFirstEpochChan,
firstBobTicker: firstBobTicker,
bobServer: bobServer, bobServer: bobServer,
firstBobChannelLink: firstBobChannelLink.(*channelLink),
secondBobChannelLink: secondBobChannelLink.(*channelLink), secondBobChannelLink: secondBobChannelLink.(*channelLink),
bobSecondBlockEpoch: bobSecondEpochChan,
secondBobTicker: secondBobTicker,
carolChannelLink: carolChannelLink.(*channelLink),
carolServer: carolServer, carolServer: carolServer,
carolBlockEpoch: carolBlockEpoch, carolChannelLink: carolChannelLink.(*channelLink),
carolTicker: carolTicker,
feeEstimator: feeEstimator, feeEstimator: feeEstimator,
globalPolicy: globalPolicy, globalPolicy: globalPolicy,

15
peer.go

@ -1,6 +1,7 @@
package main package main
import ( import (
"bytes"
"container/list" "container/list"
"fmt" "fmt"
"net" "net"
@ -9,14 +10,11 @@ import (
"time" "time"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/brontide"
"github.com/lightningnetwork/lnd/contractcourt"
"bytes"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/brontide"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
@ -537,7 +535,6 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
ForwardPackets: p.server.htlcSwitch.ForwardPackets, ForwardPackets: p.server.htlcSwitch.ForwardPackets,
FwrdingPolicy: *forwardingPolicy, FwrdingPolicy: *forwardingPolicy,
FeeEstimator: p.server.cc.feeEstimator, FeeEstimator: p.server.cc.feeEstimator,
BlockEpochs: blockEpoch,
PreimageCache: p.server.witnessBeacon, PreimageCache: p.server.witnessBeacon,
ChainEvents: chainEvents, ChainEvents: chainEvents,
UpdateContractSignals: func(signals *contractcourt.ContractSignals) error { UpdateContractSignals: func(signals *contractcourt.ContractSignals) error {
@ -553,9 +550,11 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
time.NewTicker(time.Minute)), time.NewTicker(time.Minute)),
BatchSize: 10, BatchSize: 10,
UnsafeReplay: cfg.UnsafeReplay, UnsafeReplay: cfg.UnsafeReplay,
MinFeeUpdateTimeout: htlcswitch.DefaultMinLinkFeeUpdateTimeout,
MaxFeeUpdateTimeout: htlcswitch.DefaultMaxLinkFeeUpdateTimeout,
} }
link := htlcswitch.NewChannelLink(linkCfg, lnChan,
uint32(currentHeight)) link := htlcswitch.NewChannelLink(linkCfg, lnChan)
// With the channel link created, we'll now notify the htlc switch so // With the channel link created, we'll now notify the htlc switch so
// this channel can be used to dispatch local payments and also // this channel can be used to dispatch local payments and also

@ -284,6 +284,11 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
debugPre[:], debugHash[:]) debugPre[:], debugHash[:])
} }
_, currentHeight, err := s.cc.chainIO.GetBestBlock()
if err != nil {
return nil, err
}
s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{
DB: chanDB, DB: chanDB,
SelfKey: s.identityPriv.PubKey(), SelfKey: s.identityPriv.PubKey(),
@ -313,7 +318,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
SwitchPackager: channeldb.NewSwitchPackager(), SwitchPackager: channeldb.NewSwitchPackager(),
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter, ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey), FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey),
}) Notifier: s.cc.chainNotifier,
}, uint32(currentHeight))
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -341,10 +341,17 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
breachArbiter: breachArbiter, breachArbiter: breachArbiter,
chainArb: chainArb, chainArb: chainArb,
} }
_, currentHeight, err := s.cc.chainIO.GetBestBlock()
if err != nil {
return nil, nil, nil, nil, err
}
htlcSwitch, err := htlcswitch.New(htlcswitch.Config{ htlcSwitch, err := htlcswitch.New(htlcswitch.Config{
DB: dbAlice, DB: dbAlice,
SwitchPackager: channeldb.NewSwitchPackager(), SwitchPackager: channeldb.NewSwitchPackager(),
}) Notifier: notifier,
}, uint32(currentHeight))
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }