Merge pull request #761 from cfromknecht/switch-persistence

Switch Persistence [ALL]: Forwarding Packages + Sphinx Replay Protection + Circuit Persistence
This commit is contained in:
Olaoluwa Osuntokun 2018-03-09 22:40:02 -08:00 committed by GitHub
commit bfa76bad49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 7175 additions and 1000 deletions

@ -1401,6 +1401,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
RemoteChanCfg: bobCfg,
IdentityPub: aliceKeyPub,
FundingOutpoint: *prevOut,
ShortChanID: shortChanID,
ChanType: channeldb.SingleFunder,
IsInitiator: true,
Capacity: channelCapacity,
@ -1417,6 +1418,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
RemoteChanCfg: aliceCfg,
IdentityPub: bobKeyPub,
FundingOutpoint: *prevOut,
ShortChanID: shortChanID,
ChanType: channeldb.SingleFunder,
IsInitiator: false,
Capacity: channelCapacity,

@ -170,6 +170,8 @@ type config struct {
DebugHTLC bool `long:"debughtlc" description:"Activate the debug htlc mode. With the debug HTLC mode, all payments sent use a pre-determined R-Hash. Additionally, all HTLCs sent to a node with the debug HTLC R-Hash are immediately settled in the next available state transition."`
HodlHTLC bool `long:"hodlhtlc" description:"Activate the hodl HTLC mode. With hodl HTLC mode, all incoming HTLCs will be accepted by the receiving node, but no attempt will be made to settle the payment with the sender."`
UnsafeDisconnect bool `long:"unsafe-disconnect" description:"Allows the rpcserver to intentionally disconnect from peers with open channels. USED FOR TESTING ONLY."`
UnsafeReplay bool `long:"unsafe-replay" description:"Causes a link to replay the adds on its commitment txn after starting up, this enables testing of the sphinx replay logic."`
MaxPendingChannels int `long:"maxpendingchannels" description:"The maximum number of incoming pending channels permitted per peer."`
Bitcoin *chainConfig `group:"Bitcoin" namespace:"bitcoin"`

@ -1,184 +1,229 @@
package htlcswitch
import (
"fmt"
"sync"
"encoding/binary"
"io"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
)
// PaymentCircuit is used by the HTLC switch subsystem to determine the
// backwards path for the settle/fail HTLC messages. A payment circuit
// will be created once a channel link forwards the HTLC add request and
// removed when we receive a settle/fail HTLC message.
// EmptyCircuitKey is a default value for an outgoing circuit key returned when
// a circuit's keystone has not been set. Note that this value is invalid for
// use as a keystone, since the outgoing channel id can never be equal to
// sourceHop.
var EmptyCircuitKey CircuitKey
// CircuitKey is a tuple of channel ID and HTLC ID, used to uniquely identify
// HTLCs in a circuit. Circuits are identified primarily by the circuit key of
// the incoming HTLC. However, a circuit may also be referenced by its outgoing
// circuit key after the HTLC has been forwarded via the outgoing link.
type CircuitKey = channeldb.CircuitKey
// PaymentCircuit is used by the switch as placeholder between when the
// switch makes a forwarding decision and the outgoing link determines the
// proper HTLC ID for the local log. After the outgoing HTLC ID has been
// determined, the half circuit will be converted into a full PaymentCircuit.
type PaymentCircuit struct {
// AddRef is the forward reference of the Add update in the incoming
// link's forwarding package. This value is set on the htlcPacket of the
// returned settle/fail so that it can be removed from disk.
AddRef channeldb.AddRef
// Incoming is the circuit key identifying the incoming channel and htlc
// index from which this ADD originates.
Incoming CircuitKey
// Outgoing is the circuit key identifying the outgoing channel, and the
// HTLC index that was used to forward the ADD. It will be nil if this
// circuit's keystone has not been set.
Outgoing *CircuitKey
// PaymentHash used as unique identifier of payment.
PaymentHash [32]byte
// IncomingChanID identifies the channel from which add HTLC request
// came and to which settle/fail HTLC request will be returned back.
// Once the switch forwards the settle/fail message to the src the
// circuit is considered to be completed.
IncomingChanID lnwire.ShortChannelID
// IncomingAmount is the value of the HTLC from the incoming link.
IncomingAmount lnwire.MilliSatoshi
// IncomingHTLCID is the ID in the update_add_htlc message we received
// from the incoming channel, which will be included in any settle/fail
// messages we send back.
IncomingHTLCID uint64
// IncomingAmt is the value of the incoming HTLC. If we take this and
// subtract it from the OutgoingAmt, then we'll compute the total fee
// attached to this payment circuit.
IncomingAmt lnwire.MilliSatoshi
// OutgoingChanID identifies the channel to which we propagate the HTLC
// add update and from which we are expecting to receive HTLC
// settle/fail request back.
OutgoingChanID lnwire.ShortChannelID
// OutgoingHTLCID is the ID in the update_add_htlc message we sent to
// the outgoing channel.
OutgoingHTLCID uint64
// OutgoingAmt is the value of the outgoing HTLC. If we subtract this
// from the IncomingAmt, then we'll compute the total fee attached to
// this payment circuit.
OutgoingAmt lnwire.MilliSatoshi
// OutgoingAmount specifies the value of the HTLC leaving the switch,
// either as a payment or forwarded amount.
OutgoingAmount lnwire.MilliSatoshi
// ErrorEncrypter is used to re-encrypt the onion failure before
// sending it back to the originator of the payment.
ErrorEncrypter ErrorEncrypter
// LoadedFromDisk is set true for any circuits loaded after the circuit
// map is reloaded from disk.
//
// NOTE: This value is determined implicitly during a restart. It is not
// persisted, and should never be set outside the circuit map.
LoadedFromDisk bool
}
// circuitKey is a channel ID, HTLC ID tuple used as an identifying key for a
// payment circuit. The circuit map is keyed with the identifier for the
// outgoing HTLC
type circuitKey struct {
chanID lnwire.ShortChannelID
htlcID uint64
// HasKeystone returns true if an outgoing link has assigned this circuit's
// outgoing circuit key.
func (c *PaymentCircuit) HasKeystone() bool {
return c.Outgoing != nil
}
// String returns a string representation of the circuitKey.
func (k *circuitKey) String() string {
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.chanID, k.htlcID)
}
// newPaymentCircuit initializes a payment circuit on the heap using the payment
// hash and an in-memory htlc packet.
func newPaymentCircuit(hash *[32]byte, pkt *htlcPacket) *PaymentCircuit {
var addRef channeldb.AddRef
if pkt.sourceRef != nil {
addRef = *pkt.sourceRef
}
// CircuitMap is a data structure that implements thread safe storage of
// circuit routing information. The switch consults a circuit map to determine
// where to forward HTLC update messages. Each circuit is stored with its
// outgoing HTLC as the primary key because, each offered HTLC has at most one
// received HTLC, but there may be multiple offered or received HTLCs with the
// same payment hash. Circuits are also indexed to provide fast lookups by
// payment hash.
//
// TODO(andrew.shvv) make it persistent
type CircuitMap struct {
mtx sync.RWMutex
circuits map[circuitKey]*PaymentCircuit
hashIndex map[[32]byte]map[PaymentCircuit]struct{}
}
// NewCircuitMap creates a new instance of the CircuitMap.
func NewCircuitMap() *CircuitMap {
return &CircuitMap{
circuits: make(map[circuitKey]*PaymentCircuit),
hashIndex: make(map[[32]byte]map[PaymentCircuit]struct{}),
return &PaymentCircuit{
AddRef: addRef,
Incoming: CircuitKey{
ChanID: pkt.incomingChanID,
HtlcID: pkt.incomingHTLCID,
},
PaymentHash: *hash,
IncomingAmount: pkt.incomingAmount,
OutgoingAmount: pkt.amount,
ErrorEncrypter: pkt.obfuscator,
}
}
// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC
// IDs. Returns nil if there is no such circuit.
func (cm *CircuitMap) LookupByHTLC(chanID lnwire.ShortChannelID, htlcID uint64) *PaymentCircuit {
cm.mtx.RLock()
key := circuitKey{
chanID: chanID,
htlcID: htlcID,
// makePaymentCircuit initalizes a payment circuit on the stack using the
// payment hash and an in-memory htlc packet.
func makePaymentCircuit(hash *[32]byte, pkt *htlcPacket) PaymentCircuit {
var addRef channeldb.AddRef
if pkt.sourceRef != nil {
addRef = *pkt.sourceRef
}
circuit := cm.circuits[key]
cm.mtx.RUnlock()
return circuit
return PaymentCircuit{
AddRef: addRef,
Incoming: CircuitKey{
ChanID: pkt.incomingChanID,
HtlcID: pkt.incomingHTLCID,
},
PaymentHash: *hash,
IncomingAmount: pkt.incomingAmount,
OutgoingAmount: pkt.amount,
ErrorEncrypter: pkt.obfuscator,
}
}
// LookupByPaymentHash looks up and returns any payment circuits with a given
// payment hash.
func (cm *CircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
cm.mtx.RLock()
var circuits []*PaymentCircuit
if circuitSet, ok := cm.hashIndex[hash]; ok {
circuits = make([]*PaymentCircuit, 0, len(circuitSet))
for circuit := range circuitSet {
circuits = append(circuits, &circuit)
}
// Encode writes a PaymentCircuit to the provided io.Writer.
func (c *PaymentCircuit) Encode(w io.Writer) error {
if err := c.AddRef.Encode(w); err != nil {
return err
}
cm.mtx.RUnlock()
return circuits
if err := c.Incoming.Encode(w); err != nil {
return err
}
if _, err := w.Write(c.PaymentHash[:]); err != nil {
return err
}
var scratch [8]byte
binary.BigEndian.PutUint64(scratch[:], uint64(c.IncomingAmount))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
binary.BigEndian.PutUint64(scratch[:], uint64(c.OutgoingAmount))
if _, err := w.Write(scratch[:]); err != nil {
return err
}
// Defaults to EncrypterTypeNone.
var encrypterType EncrypterType
if c.ErrorEncrypter != nil {
encrypterType = c.ErrorEncrypter.Type()
}
err := binary.Write(w, binary.BigEndian, encrypterType)
if err != nil {
return err
}
// Skip encoding of error encrypter if this half add does not have one.
if encrypterType == EncrypterTypeNone {
return nil
}
return c.ErrorEncrypter.Encode(w)
}
// Add adds a new active payment circuit to the CircuitMap.
func (cm *CircuitMap) Add(circuit *PaymentCircuit) error {
cm.mtx.Lock()
key := circuitKey{
chanID: circuit.OutgoingChanID,
htlcID: circuit.OutgoingHTLCID,
// Decode reads a PaymentCircuit from the provided io.Reader.
func (c *PaymentCircuit) Decode(r io.Reader) error {
if err := c.AddRef.Decode(r); err != nil {
return err
}
cm.circuits[key] = circuit
// Add circuit to the hash index.
if _, ok := cm.hashIndex[circuit.PaymentHash]; !ok {
cm.hashIndex[circuit.PaymentHash] = make(map[PaymentCircuit]struct{})
if err := c.Incoming.Decode(r); err != nil {
return err
}
cm.hashIndex[circuit.PaymentHash][*circuit] = struct{}{}
cm.mtx.Unlock()
return nil
if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil {
return err
}
var scratch [8]byte
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
c.IncomingAmount = lnwire.MilliSatoshi(
binary.BigEndian.Uint64(scratch[:]))
if _, err := io.ReadFull(r, scratch[:]); err != nil {
return err
}
c.OutgoingAmount = lnwire.MilliSatoshi(
binary.BigEndian.Uint64(scratch[:]))
// Read the encrypter type used for this circuit.
var encrypterType EncrypterType
err := binary.Read(r, binary.BigEndian, &encrypterType)
if err != nil {
return err
}
switch encrypterType {
case EncrypterTypeNone:
// No encrypter was provided, such as when the payment is
// locally initiated.
return nil
case EncrypterTypeSphinx:
// Sphinx encrypter was used as this is a forwarded HTLC.
c.ErrorEncrypter = NewSphinxErrorEncrypter()
case EncrypterTypeMock:
// Test encrypter.
c.ErrorEncrypter = NewMockObfuscator()
default:
return UnknownEncrypterType(encrypterType)
}
return c.ErrorEncrypter.Decode(r)
}
// Remove destroys the target circuit by removing it from the circuit map.
func (cm *CircuitMap) Remove(chanID lnwire.ShortChannelID, htlcID uint64) error {
cm.mtx.Lock()
defer cm.mtx.Unlock()
// Look up circuit so that pointer can be matched in the hash index.
key := circuitKey{
chanID: chanID,
htlcID: htlcID,
}
circuit, found := cm.circuits[key]
if !found {
return errors.Errorf("Can't find circuit for HTLC %v", key)
}
delete(cm.circuits, key)
// Remove circuit from hash index.
circuitsWithHash, ok := cm.hashIndex[circuit.PaymentHash]
if !ok {
return errors.Errorf("Can't find circuit in hash index for HTLC %v",
key)
}
if _, ok = circuitsWithHash[*circuit]; !ok {
return errors.Errorf("Can't find circuit in hash index for HTLC %v",
key)
}
delete(circuitsWithHash, *circuit)
if len(circuitsWithHash) == 0 {
delete(cm.hashIndex, circuit.PaymentHash)
}
return nil
// InKey returns the primary identifier for the circuit corresponding to the
// incoming HTLC.
func (c *PaymentCircuit) InKey() CircuitKey {
return c.Incoming
}
// pending returns number of circuits which are waiting for to be completed
// (settle/fail responses to be received).
func (cm *CircuitMap) pending() int {
cm.mtx.RLock()
count := len(cm.circuits)
cm.mtx.RUnlock()
return count
// OutKey returns the keystone identifying the outgoing link and HTLC ID. If the
// circuit hasn't been completed, this method returns an EmptyKeystone, which is
// an invalid outgoing circuit key. Only call this method if HasKeystone returns
// true.
func (c *PaymentCircuit) OutKey() CircuitKey {
if c.Outgoing != nil {
return *c.Outgoing
}
return EmptyCircuitKey
}

846
htlcswitch/circuit_map.go Normal file

@ -0,0 +1,846 @@
package htlcswitch
import (
"bytes"
"fmt"
"sync"
"github.com/boltdb/bolt"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrCorruptedCircuitMap indicates that the on-disk bucketing structure
// has altered since the circuit map instance was initialized.
ErrCorruptedCircuitMap = errors.New("circuit map has been corrupted")
// ErrCircuitNotInHashIndex indicates that a particular circuit did not
// appear in the in-memory hash index.
ErrCircuitNotInHashIndex = errors.New("payment circuit not found in " +
"hash index")
// ErrUnknownCircuit signals that circuit could not be removed from the
// map because it was not found.
ErrUnknownCircuit = errors.New("unknown payment circuit")
// ErrCircuitClosing signals that an htlc has already closed this
// circuit in-memory.
ErrCircuitClosing = errors.New("circuit has already been closed")
// ErrDuplicateCircuit signals that this circuit was previously
// added.
ErrDuplicateCircuit = errors.New("duplicate circuit add")
// ErrUnknownKeystone signals that no circuit was found using the
// outgoing circuit key.
ErrUnknownKeystone = errors.New("unknown circuit keystone")
// ErrDuplicateKeystone signals that this circuit was previously
// assigned a keystone.
ErrDuplicateKeystone = errors.New("cannot add duplicate keystone")
)
// CircuitModifier is a common interface used by channel links to modify the
// contents of the circuit map maintained by the switch.
type CircuitModifier interface {
// OpenCircuits preemptively records a batch keystones that will mark
// currently pending circuits as open. These changes can be rolled back
// on restart if the outgoing Adds do not make it into a commitment txn.
OpenCircuits(...Keystone) error
// TrimOpenCircuits removes a channel's open channels with htlc indexes
// above `start`.
TrimOpenCircuits(chanID lnwire.ShortChannelID, start uint64) error
// DeleteCircuits removes the incoming circuit key to remove all
// persistent references to a circuit. Returns a ErrUnknownCircuit if
// any of the incoming keys are not known.
DeleteCircuits(inKeys ...CircuitKey) error
}
// CircuitFwdActions represents the forwarding decision made by the circuit map,
// and is returned from CommitCircuits. The sequence of circuits provided to
// CommitCircuits is split into three subsequences, allowing the caller to do an
// in-order scan, comparing the head of each subsequence, to determine the
// decision made by the circuit map.
type CircuitFwdActions struct {
// Adds is the subsequence of circuits that were successfully committed
// in the circuit map.
Adds []*PaymentCircuit
// Drops is the subsequence of circuits for which no action should be
// done.
Drops []*PaymentCircuit
// Fails is the subsequence of circuits that should be failed back by
// the calling link.
Fails []*PaymentCircuit
}
// CircuitMap is an interface for managing the construction and teardown of
// payment circuits used by the switch.
type CircuitMap interface {
CircuitModifier
// CommitCircuits attempts to add the given circuits to the circuit
// map. The list of circuits is split into three distinct subsequences,
// corresponding to adds, drops, and fails. Adds should be forwarded to
// the switch, while fails should be failed back locally within the
// calling link.
CommitCircuits(circuit ...*PaymentCircuit) (*CircuitFwdActions, error)
// CloseCircuit marks the circuit identified by `outKey` as closing
// in-memory, which prevents duplicate settles/fails from completing an
// open circuit twice.
CloseCircuit(outKey CircuitKey) (*PaymentCircuit, error)
// FailCircuit is used by locally failed HTLCs to mark the circuit
// identified by `inKey` as closing in-memory, which prevents duplicate
// settles/fails from being accepted for the same circuit.
FailCircuit(inKey CircuitKey) (*PaymentCircuit, error)
// LookupCircuit queries the circuit map for the circuit identified by
// inKey.
LookupCircuit(inKey CircuitKey) *PaymentCircuit
// LookupOpenCircuit queries the circuit map for a circuit identified by
// its outgoing circuit key.
LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit
// LookupByPaymentHash queries the circuit map and returns all open
// circuits that use the given payment hash.
LookupByPaymentHash(hash [32]byte) []*PaymentCircuit
// NumPending returns the total number of active circuits added by
// CommitCircuits.
NumPending() int
// NumOpen returns the number of circuits with HTLCs that have been
// forwarded via an outgoing link.
NumOpen() int
}
var (
// circuitAddKey is the key used to retrieve the bucket containing
// payment circuits. A circuit records information about how to return a
// packet to the source link, potentially including an error encrypter
// for applying this hop's encryption to the payload in the reverse
// direction.
circuitAddKey = []byte("circuit-adds")
// circuitKeystoneKey is used to retrieve the bucket containing circuit
// keystones, which are set in place once a forwarded packet is assigned
// an index on an outgoing commitment txn.
circuitKeystoneKey = []byte("circuit-keystones")
)
// circuitMap is a data structure that implements thread safe, persistent
// storage of circuit routing information. The switch consults a circuit map to
// determine where to forward returning HTLC update messages. Circuits are
// always identifiable by their incoming CircuitKey, in addition to their
// outgoing CircuitKey if the circuit is fully-opened.
type circuitMap struct {
// db provides the persistent storage engine for the circuit map.
// TODO(conner): create abstraction to allow for the substitution of
// other persistence engines.
db *channeldb.DB
mtx sync.RWMutex
// pending is an in-memory mapping of all half payment circuits, and
// is kept in sync with the on-disk contents of the circuit map.
pending map[CircuitKey]*PaymentCircuit
// opened is an in-memory mapping of all full payment circuits, which is
// also synchronized with the persistent state of the circuit map.
opened map[CircuitKey]*PaymentCircuit
// closed is an in-memory set of circuits for which the switch has
// received a settle or fail. This precedes the actual deletion of a
// circuit from disk.
closed map[CircuitKey]struct{}
// hashIndex is a volatile index that facilitates fast queries by
// payment hash against the contents of circuits. This index can be
// reconstructed entirely from the set of persisted full circuits on
// startup.
hashIndex map[[32]byte]map[CircuitKey]struct{}
}
// NewCircuitMap creates a new instance of the circuitMap.
func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) {
cm := &circuitMap{
db: db,
}
// Initialize the on-disk buckets used by the circuit map.
if err := cm.initBuckets(); err != nil {
return nil, err
}
// Load any previously persisted circuit into back into memory.
if err := cm.restoreMemState(); err != nil {
return nil, err
}
// Trim any keystones that were not committed in an outgoing commit txn.
//
// NOTE: This operation will be applied to the persistent state of all
// active channels. Therefore, it must be called before any links are
// created to avoid interfering with normal operation.
if err := cm.trimAllOpenCircuits(); err != nil {
return nil, err
}
return cm, nil
}
// initBuckets ensures that the primary buckets used by the circuit are
// initialized so that we can assume their existence after startup.
func (cm *circuitMap) initBuckets() error {
return cm.db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(circuitKeystoneKey)
if err != nil {
return err
}
_, err = tx.CreateBucketIfNotExists(circuitAddKey)
return err
})
}
// restoreMemState loads the contents of the half circuit and full circuit buckets
// from disk and reconstructs the in-memory representation of the circuit map.
// Afterwards, the state of the hash index is reconstructed using the recovered
// set of full circuits.
func (cm *circuitMap) restoreMemState() error {
var (
opened = make(map[CircuitKey]*PaymentCircuit)
pending = make(map[CircuitKey]*PaymentCircuit)
)
if err := cm.db.View(func(tx *bolt.Tx) error {
// Restore any of the circuits persisted in the circuit bucket
// back into memory.
circuitBkt := tx.Bucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
if err := circuitBkt.ForEach(func(_, v []byte) error {
circuit, err := decodeCircuit(v)
if err != nil {
return err
}
circuit.LoadedFromDisk = true
pending[circuit.Incoming] = circuit
return nil
}); err != nil {
return err
}
// Furthermore, load the keystone bucket and resurrect the
// keystones used in any open circuits.
keystoneBkt := tx.Bucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
if err := keystoneBkt.ForEach(func(k, v []byte) error {
var (
inKey CircuitKey
outKey = &CircuitKey{}
)
// Decode the incoming and outgoing circuit keys.
if err := inKey.SetBytes(v); err != nil {
return err
}
if err := outKey.SetBytes(k); err != nil {
return err
}
// Retrieve the pending circuit, set its keystone, then
// add it to the opened map.
circuit := pending[inKey]
circuit.Outgoing = outKey
opened[*outKey] = circuit
return nil
}); err != nil {
return err
}
return nil
}); err != nil {
return err
}
cm.pending = pending
cm.opened = opened
cm.closed = make(map[CircuitKey]struct{})
// Finally, reconstruct the hash index by running through our set of
// open circuits.
cm.hashIndex = make(map[[32]byte]map[CircuitKey]struct{})
for _, circuit := range opened {
cm.addCircuitToHashIndex(circuit)
}
return nil
}
// decodeCircuit reconstructs an in-memory payment circuit from a byte slice.
// The byte slice is assumed to have been generated by the circuit's Encode
// method.
func decodeCircuit(v []byte) (*PaymentCircuit, error) {
var circuit = &PaymentCircuit{}
circuitReader := bytes.NewReader(v)
if err := circuit.Decode(circuitReader); err != nil {
return nil, err
}
return circuit, nil
}
// trimAllOpenCircuits reads the set of active channels from disk and trims
// keystones for any non-pending channels. This method is intended to be called
// on startup. Each link will also trim it's own circuits upon startup.
//
// NOTE: This operation will be applied to the persistent state of all active
// channels. Therefore, it must be called before any links are created to avoid
// interfering with normal operation.
func (cm *circuitMap) trimAllOpenCircuits() error {
activeChannels, err := cm.db.FetchAllChannels()
if err != nil {
return err
}
for _, activeChannel := range activeChannels {
if activeChannel.IsPending {
continue
}
chanID := activeChannel.ShortChanID
start := activeChannel.LocalCommitment.LocalHtlcIndex
if err := cm.TrimOpenCircuits(chanID, start); err != nil {
return err
}
}
return nil
}
// TrimOpenCircuits removes a channel's keystones above the short chan id's
// highest committed htlc index. This has the effect of returning those circuits
// to a half-open state. Since opening of circuits is done in advance of
// actually committing the Add htlcs into a commitment txn, this allows circuits
// to be opened preemetively, since we can roll them back after any failures.
func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
start uint64) error {
var trimmedOutKeys []CircuitKey
// Scan forward from the last unacked htlc id, stopping as soon as we
// don't find any more. Outgoing htlc id's must be assigned in order, so
// there should never be disjoint segments of keystones to trim.
cm.mtx.Lock()
for i := start; ; i++ {
outKey := CircuitKey{
ChanID: chanID,
HtlcID: i,
}
circuit, ok := cm.opened[outKey]
if !ok {
break
}
circuit.Outgoing = nil
delete(cm.opened, outKey)
trimmedOutKeys = append(trimmedOutKeys, outKey)
cm.removeCircuitFromHashIndex(circuit)
}
cm.mtx.Unlock()
if len(trimmedOutKeys) == 0 {
return nil
}
return cm.db.Update(func(tx *bolt.Tx) error {
keystoneBkt := tx.Bucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
for _, outKey := range trimmedOutKeys {
err := keystoneBkt.Delete(outKey.Bytes())
if err != nil {
return err
}
}
return nil
})
}
// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC
// IDs. Returns nil if there is no such circuit.
func (cm *circuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return cm.pending[inKey]
}
// LookupOpenCircuit searches for the circuit identified by its outgoing circuit
// key.
func (cm *circuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return cm.opened[outKey]
}
// LookupByPaymentHash looks up and returns any payment circuits with a given
// payment hash.
func (cm *circuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
var circuits []*PaymentCircuit
if circuitSet, ok := cm.hashIndex[hash]; ok {
// Iterate over the outgoing circuit keys found with this hash,
// and retrieve the circuit from the opened map.
circuits = make([]*PaymentCircuit, 0, len(circuitSet))
for key := range circuitSet {
if circuit, ok := cm.opened[key]; ok {
circuits = append(circuits, circuit)
}
}
}
return circuits
}
// CommitCircuits accepts any number of circuits and persistently adds them to
// the switch's circuit map. The method returns a list of circuits that had not
// been seen prior by the switch. A link should only forward HTLCs corresponding
// to the returned circuits to the switch.
//
// NOTE: This method uses batched writes to improve performance, gains will only
// be realized if it is called concurrently from separate goroutines.
func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
*CircuitFwdActions, error) {
actions := &CircuitFwdActions{}
// If an empty list was passed, return early to avoid grabbing the lock.
if len(circuits) == 0 {
return actions, nil
}
// First, we reconcile the provided circuits with our set of pending
// circuits to construct a set of new circuits that need to be written
// to disk. The circuit's pointer is stored so that we only permit this
// exact circuit to be forwarded through the switch. If a circuit is
// already pending, the htlc will be reforwarded by the switch.
//
// NOTE: We track an additional addFails subsequence, which permits us
// to fail back all packets that weren't dropped if we encounter an
// error when committing the circuits.
cm.mtx.Lock()
var adds, drops, fails, addFails []*PaymentCircuit
for _, circuit := range circuits {
inKey := circuit.InKey()
if foundCircuit, ok := cm.pending[inKey]; ok {
switch {
// This circuit has a keystone, it's waiting for a
// response from the remote peer on the outgoing link.
// Drop it like it's hot, ensure duplicates get caught.
case foundCircuit.HasKeystone():
drops = append(drops, circuit)
// If no keystone is set and the switch has not been
// restarted, the corresponding packet should still be
// in the outgoing link's mailbox. It will be delivered
// if it comes online before the switch goes down.
//
// NOTE: Dropping here prevents a flapping, incoming
// link from failing a duplicate add while it is still
// in the server's memory mailboxes.
case !foundCircuit.LoadedFromDisk:
drops = append(drops, circuit)
// Otherwise, the in-mem packet has been lost due to a
// restart. It is now safe to send back a failure along
// the incoming link. The incoming link should be able
// detect and ignore duplicate packets of this type.
default:
fails = append(fails, circuit)
addFails = append(addFails, circuit)
}
continue
}
cm.pending[inKey] = circuit
adds = append(adds, circuit)
addFails = append(addFails, circuit)
}
cm.mtx.Unlock()
// If all circuits are dropped or failed, we are done.
if len(adds) == 0 {
actions.Drops = drops
actions.Fails = fails
return actions, nil
}
// Now, optimistically serialize the circuits to add.
var bs = make([]bytes.Buffer, len(adds))
for i, circuit := range adds {
if err := circuit.Encode(&bs[i]); err != nil {
actions.Drops = drops
actions.Fails = addFails
return actions, err
}
}
// Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance.
err := cm.db.Batch(func(tx *bolt.Tx) error {
circuitBkt := tx.Bucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
for i, circuit := range adds {
inKeyBytes := circuit.InKey().Bytes()
circuitBytes := bs[i].Bytes()
err := circuitBkt.Put(inKeyBytes, circuitBytes)
if err != nil {
return err
}
}
return nil
})
// Return if the write succeeded.
if err == nil {
actions.Adds = adds
actions.Drops = drops
actions.Fails = fails
return actions, nil
}
// Otherwise, rollback the circuits added to the pending set if the
// write failed.
cm.mtx.Lock()
for _, circuit := range adds {
delete(cm.pending, circuit.InKey())
}
cm.mtx.Unlock()
// Since our write failed, we will return the dropped packets and mark
// all other circuits as failed.
actions.Drops = drops
actions.Fails = addFails
return actions, err
}
// Keystone is a tuple binding an incoming and outgoing CircuitKey. Keystones
// are preemptively written by an outgoing link before signing a new commitment
// state, and cements which HTLCs we are awaiting a response from a remote peer.
type Keystone struct {
InKey CircuitKey
OutKey CircuitKey
}
// String returns a human readable description of the Keystone.
func (k *Keystone) String() string {
return fmt.Sprintf("%s --> %s", k.InKey, k.OutKey)
}
// OpenCircuits sets the outgoing circuit key for the circuit identified by
// inKey, persistently marking the circuit as opened. After the changes have
// been persisted, the circuit map's in-memory indexes are updated so that this
// circuit can be queried using LookupByKeystone or LookupByPaymentHash.
func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
if len(keystones) == 0 {
return nil
}
// Check that all keystones correspond to committed-but-unopened
// circuits.
cm.mtx.RLock()
openedCircuits := make([]*PaymentCircuit, 0, len(keystones))
for _, ks := range keystones {
if _, ok := cm.opened[ks.OutKey]; ok {
cm.mtx.RUnlock()
return ErrDuplicateKeystone
}
circuit, ok := cm.pending[ks.InKey]
if !ok {
cm.mtx.RUnlock()
return ErrUnknownCircuit
}
openedCircuits = append(openedCircuits, circuit)
}
cm.mtx.RUnlock()
err := cm.db.Update(func(tx *bolt.Tx) error {
// Now, load the circuit bucket to which we will write the
// already serialized circuit.
keystoneBkt := tx.Bucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
for _, ks := range keystones {
outBytes := ks.OutKey.Bytes()
inBytes := ks.InKey.Bytes()
err := keystoneBkt.Put(outBytes, inBytes)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return err
}
cm.mtx.Lock()
for i, circuit := range openedCircuits {
ks := keystones[i]
// Since our persistent operation was successful, we can now
// modify the in memory representations. Set the outgoing
// circuit key on our pending circuit, add the same circuit to
// set of opened circuits, and add this circuit to the hash
// index.
circuit.Outgoing = &CircuitKey{}
*circuit.Outgoing = ks.OutKey
cm.opened[ks.OutKey] = circuit
cm.addCircuitToHashIndex(circuit)
}
cm.mtx.Unlock()
return nil
}
// addCirciutToHashIndex inserts a circuit into the circuit map's hash index, so
// that it can be queried using LookupByPaymentHash.
func (cm *circuitMap) addCircuitToHashIndex(c *PaymentCircuit) {
if _, ok := cm.hashIndex[c.PaymentHash]; !ok {
cm.hashIndex[c.PaymentHash] = make(map[CircuitKey]struct{})
}
cm.hashIndex[c.PaymentHash][c.OutKey()] = struct{}{}
}
// FailCircuit marks the circuit identified by `inKey` as closing in-memory,
// which prevents duplicate settles/fails from completing an open circuit twice.
func (cm *circuitMap) FailCircuit(
inKey CircuitKey) (*PaymentCircuit, error) {
cm.mtx.Lock()
defer cm.mtx.Unlock()
circuit, ok := cm.pending[inKey]
if !ok {
return nil, ErrUnknownCircuit
}
_, ok = cm.closed[inKey]
if ok {
return nil, ErrCircuitClosing
}
cm.closed[inKey] = struct{}{}
return circuit, nil
}
// CloseCircuit marks the circuit identified by `outKey` as closing
// in-memory, which prevents duplicate settles/fails from completing an open
// circuit twice.
func (cm *circuitMap) CloseCircuit(
outKey CircuitKey) (*PaymentCircuit, error) {
cm.mtx.Lock()
defer cm.mtx.Unlock()
circuit, ok := cm.opened[outKey]
if !ok {
return nil, ErrUnknownCircuit
}
_, ok = cm.closed[circuit.Incoming]
if ok {
return nil, ErrCircuitClosing
}
cm.closed[circuit.Incoming] = struct{}{}
return circuit, nil
}
// DeleteCircuits destroys the target circuit by removing it from the circuit map,
// additionally removing the circuit's keystone if the HTLC was forwarded
// through an outgoing link. The circuit should be identified by its incoming
// circuit key.
func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
var (
closingCircuits = make(map[CircuitKey]struct{})
removedCircuits = make(map[CircuitKey]*PaymentCircuit)
)
cm.mtx.Lock()
// First check that all provided keys are still known to the circuit
// map.
for _, inKey := range inKeys {
if _, ok := cm.pending[inKey]; !ok {
cm.mtx.Unlock()
return ErrUnknownCircuit
}
}
// If no offenders were found, remove any references to the circuit from
// memory, keeping track of which circuits were removed, and which ones
// had been marked closed. This can be used to restore these entries
// later if the persistent removal fails.
for _, inKey := range inKeys {
circuit := cm.pending[inKey]
delete(cm.pending, inKey)
if _, ok := cm.closed[inKey]; ok {
closingCircuits[inKey] = struct{}{}
delete(cm.closed, inKey)
}
if circuit.HasKeystone() {
delete(cm.opened, circuit.OutKey())
cm.removeCircuitFromHashIndex(circuit)
}
removedCircuits[inKey] = circuit
}
cm.mtx.Unlock()
err := cm.db.Batch(func(tx *bolt.Tx) error {
for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the
// outgoing circuit key.
if circuit.HasKeystone() {
keystoneBkt := tx.Bucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
}
outKey := circuit.OutKey()
err := keystoneBkt.Delete(outKey.Bytes())
if err != nil {
return err
}
}
// Remove the circuit itself based on the incoming
// circuit key.
circuitBkt := tx.Bucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
}
inKey := circuit.InKey()
if err := circuitBkt.Delete(inKey.Bytes()); err != nil {
return err
}
}
return nil
})
// Return if the write succeeded.
if err == nil {
return nil
}
// If the persistent changes failed, restore the circuit map to it's
// previous state.
cm.mtx.Lock()
for inKey, circuit := range removedCircuits {
cm.pending[inKey] = circuit
if _, ok := closingCircuits[inKey]; ok {
cm.closed[inKey] = struct{}{}
}
if circuit.HasKeystone() {
cm.opened[circuit.OutKey()] = circuit
cm.addCircuitToHashIndex(circuit)
}
}
cm.mtx.Unlock()
return err
}
// removeCircuitFromHashIndex removes the given circuit from the hash index,
// pruning any unnecessary memory optimistically.
func (cm *circuitMap) removeCircuitFromHashIndex(c *PaymentCircuit) {
// Locate bucket containing this circuit's payment hashes.
circuitsWithHash, ok := cm.hashIndex[c.PaymentHash]
if !ok {
return
}
outKey := c.OutKey()
// Remove this circuit from the set of circuitsWithHash.
delete(circuitsWithHash, outKey)
// Prune the payment hash bucket if no other entries remain.
if len(circuitsWithHash) == 0 {
delete(cm.hashIndex, c.PaymentHash)
}
}
// NumPending returns the number of active circuits added to the circuit map.
func (cm *circuitMap) NumPending() int {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return len(cm.pending)
}
// NumOpen returns the number of circuits that have been opened by way of
// setting their keystones. This is the number of HTLCs that are waiting for a
// settle/fail response from a remote peer.
func (cm *circuitMap) NumOpen() int {
cm.mtx.RLock()
defer cm.mtx.RUnlock()
return len(cm.opened)
}

@ -1,155 +1,1312 @@
package htlcswitch_test
import (
"bytes"
"io/ioutil"
"reflect"
"testing"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcutil"
)
func TestCircuitMap(t *testing.T) {
var (
hash1 = [32]byte{0x01}
hash2 = [32]byte{0x02}
hash3 = [32]byte{0x03}
)
func TestCircuitMapInit(t *testing.T) {
t.Parallel()
var hash1, hash2, hash3 [32]byte
hash1[0] = 1
hash2[0] = 2
hash3[0] = 3
// Initialize new database for circuit map.
cdb := makeCircuitDB(t, "")
_, err := htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
restartCircuitMap(t, cdb)
}
var halfCircuitTests = []struct {
hash [32]byte
inValue btcutil.Amount
outValue btcutil.Amount
chanID lnwire.ShortChannelID
htlcID uint64
encrypter htlcswitch.ErrorEncrypter
}{
{
hash: hash1,
inValue: 0,
outValue: 1000,
chanID: lnwire.NewShortChanIDFromInt(1),
htlcID: 1,
encrypter: nil,
},
{
hash: hash2,
inValue: 2100,
outValue: 2000,
chanID: lnwire.NewShortChanIDFromInt(2),
htlcID: 2,
encrypter: htlcswitch.NewMockObfuscator(),
},
{
hash: hash3,
inValue: 10000,
outValue: 9000,
chanID: lnwire.NewShortChanIDFromInt(3),
htlcID: 3,
encrypter: htlcswitch.NewSphinxErrorEncrypter(),
},
}
// TestHalfCircuitSerialization checks that the half circuits can be properly
// encoded and decoded properly. A critical responsibility of this test is to
// verify that the various ErrorEncrypter implementations can be properly
// reconstructed from a serialized half circuit.
func TestHalfCircuitSerialization(t *testing.T) {
t.Parallel()
for i, test := range halfCircuitTests {
circuit := &htlcswitch.PaymentCircuit{
PaymentHash: test.hash,
IncomingAmount: lnwire.NewMSatFromSatoshis(test.inValue),
OutgoingAmount: lnwire.NewMSatFromSatoshis(test.outValue),
Incoming: htlcswitch.CircuitKey{
ChanID: test.chanID,
HtlcID: test.htlcID,
},
ErrorEncrypter: test.encrypter,
}
// Write the half circuit to our buffer.
var b bytes.Buffer
if err := circuit.Encode(&b); err != nil {
t.Fatalf("unable to encode half payment circuit test=%d: %v", i, err)
}
// Then try to decode the serialized bytes.
var circuit2 htlcswitch.PaymentCircuit
circuitReader := bytes.NewReader(b.Bytes())
if err := circuit2.Decode(circuitReader); err != nil {
t.Fatalf("unable to decode half payment circuit test=%d: %v", i, err)
}
// Reconstructed half circuit should match the original.
if !equalIgnoreLFD(circuit, &circuit2) {
t.Fatalf("unexpected half circuit test=%d, want %v, got %v",
i, circuit, circuit2)
}
}
}
func TestCircuitMapPersistence(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
chan2 = lnwire.NewShortChanIDFromInt(2)
chan1 = lnwire.NewShortChanIDFromInt(1)
chan2 = lnwire.NewShortChanIDFromInt(2)
circuitMap htlcswitch.CircuitMap
err error
)
circuitMap := htlcswitch.NewCircuitMap()
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := circuitMap.LookupByHTLC(chan1, 0)
circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{chan1, 0})
if circuit != nil {
t.Fatalf("LookupByHTLC returned a circuit before any were added: %v",
circuit)
}
circuit1 := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 1,
},
PaymentHash: hash1,
ErrorEncrypter: htlcswitch.NewMockObfuscator(),
}
if _, err := circuitMap.CommitCircuits(circuit1); err != nil {
t.Fatalf("unable to add half circuit: %v", err)
}
// Circuit map should have one circuit that has not been fully opened.
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertHasCircuit(t, circuitMap, circuit1)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertHasCircuit(t, circuitMap, circuit1)
// Add multiple circuits with same destination channel but different HTLC
// IDs and payment hashes.
circuitMap.Add(&htlcswitch.PaymentCircuit{
PaymentHash: hash1,
IncomingChanID: chan2,
IncomingHTLCID: 1,
OutgoingChanID: chan1,
OutgoingHTLCID: 0,
})
keystone1 := htlcswitch.Keystone{
InKey: circuit1.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 0,
},
}
circuit1.Outgoing = &keystone1.OutKey
if err := circuitMap.OpenCircuits(keystone1); err != nil {
t.Fatalf("unable to add full circuit: %v", err)
}
circuitMap.Add(&htlcswitch.PaymentCircuit{
// Circuit map should reflect addition of circuit1, and the change
// should survive a restart.
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
circuit2 := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 2,
},
PaymentHash: hash2,
IncomingChanID: chan2,
IncomingHTLCID: 2,
OutgoingChanID: chan1,
OutgoingHTLCID: 1,
})
ErrorEncrypter: htlcswitch.NewMockObfuscator(),
}
if _, err := circuitMap.CommitCircuits(circuit2); err != nil {
t.Fatalf("unable to add half circuit: %v", err)
}
assertHasCircuit(t, circuitMap, circuit2)
keystone2 := htlcswitch.Keystone{
InKey: circuit2.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 1,
},
}
circuit2.Outgoing = &keystone2.OutKey
if err := circuitMap.OpenCircuits(keystone2); err != nil {
t.Fatalf("unable to add full circuit: %v", err)
}
// Should have two full circuits, one under hash1 and another under
// hash2. Both half payment circuits should have been removed when the
// full circuits were added.
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertHasCircuit(t, circuitMap, circuit2)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertHasCircuit(t, circuitMap, circuit2)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
circuit3 := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 2,
},
PaymentHash: hash3,
ErrorEncrypter: htlcswitch.NewMockObfuscator(),
}
if _, err := circuitMap.CommitCircuits(circuit3); err != nil {
t.Fatalf("unable to add half circuit: %v", err)
}
assertHasCircuit(t, circuitMap, circuit3)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertHasCircuit(t, circuitMap, circuit3)
// Add another circuit with an already-used HTLC ID but different
// destination channel.
circuitMap.Add(&htlcswitch.PaymentCircuit{
PaymentHash: hash3,
IncomingChanID: chan1,
IncomingHTLCID: 2,
OutgoingChanID: chan2,
OutgoingHTLCID: 0,
})
circuit = circuitMap.LookupByHTLC(chan1, 0)
if circuit == nil {
t.Fatal("LookupByHTLC failed to find circuit")
keystone3 := htlcswitch.Keystone{
InKey: circuit3.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 0,
},
}
if circuit.PaymentHash != hash1 || circuit.IncomingHTLCID != 1 {
t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit)
circuit3.Outgoing = &keystone3.OutKey
if err := circuitMap.OpenCircuits(keystone3); err != nil {
t.Fatalf("unable to add full circuit: %v", err)
}
circuit = circuitMap.LookupByHTLC(chan1, 1)
if circuit == nil {
t.Fatal("LookupByHTLC failed to find circuit")
}
if circuit.PaymentHash != hash2 || circuit.IncomingHTLCID != 2 {
t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit)
}
// Check that all have been marked as full circuits, and that no half
// circuits are currently being tracked.
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
circuit = circuitMap.LookupByHTLC(chan2, 0)
if circuit == nil {
t.Fatal("LookupByHTLC failed to find circuit")
// Even though a circuit was added with chan1, HTLC ID 2 as the source,
// the lookup should go by destination channel, HTLC ID.
invalidKeystone := htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 2,
}
if circuit.PaymentHash != hash3 || circuit.IncomingHTLCID != 2 {
t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit)
}
// Even though a circuit was added with chan1, HTLC ID 2 as the source, the
// lookup should go by destination channel, HTLC ID.
circuit = circuitMap.LookupByHTLC(chan1, 2)
circuit = circuitMap.LookupOpenCircuit(invalidKeystone)
if circuit != nil {
t.Fatalf("LookupByHTLC returned a circuit without being added: %v",
circuit)
}
circuit4 := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 3,
},
PaymentHash: hash1,
ErrorEncrypter: htlcswitch.NewMockObfuscator(),
}
if _, err := circuitMap.CommitCircuits(circuit4); err != nil {
t.Fatalf("unable to add half circuit: %v", err)
}
// Circuit map should still only show one circuit with hash1, since we
// have not set the keystone for circuit4.
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit4)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit4)
// Add a circuit with a destination channel and payment hash that are
// already added but a different HTLC ID.
circuitMap.Add(&htlcswitch.PaymentCircuit{
PaymentHash: hash1,
IncomingChanID: chan2,
IncomingHTLCID: 3,
OutgoingChanID: chan1,
OutgoingHTLCID: 3,
})
circuit = circuitMap.LookupByHTLC(chan1, 3)
if circuit == nil {
t.Fatal("LookupByHTLC failed to find circuit")
keystone4 := htlcswitch.Keystone{
InKey: circuit4.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
}
if circuit.PaymentHash != hash1 || circuit.IncomingHTLCID != 3 {
t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit)
circuit4.Outgoing = &keystone4.OutKey
if err := circuitMap.OpenCircuits(keystone4); err != nil {
t.Fatalf("unable to add full circuit: %v", err)
}
// Check lookups by payment hash.
circuits := circuitMap.LookupByPaymentHash(hash1)
if len(circuits) != 2 {
t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+
"hash1: expected %d, got %d", 2, len(circuits))
}
// Verify that all circuits have been fully added.
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasCircuit(t, circuitMap, circuit2)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasCircuit(t, circuitMap, circuit3)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
assertHasCircuit(t, circuitMap, circuit4)
assertHasKeystone(t, circuitMap, keystone4.OutKey, circuit4)
circuits = circuitMap.LookupByPaymentHash(hash2)
if len(circuits) != 1 {
t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+
"hash2: expected %d, got %d", 1, len(circuits))
}
// Verify that each circuit is exposed via the proper hash bucketing.
assertNumCircuitsWithHash(t, circuitMap, hash1, 2)
assertHasCircuitForHash(t, circuitMap, hash1, circuit1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertHasCircuitForHash(t, circuitMap, hash2, circuit2)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
assertHasCircuitForHash(t, circuitMap, hash3, circuit3)
// Restart, then run checks again.
cdb, circuitMap = restartCircuitMap(t, cdb)
// Verify that all circuits have been fully added.
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasCircuit(t, circuitMap, circuit2)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasCircuit(t, circuitMap, circuit3)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
assertHasCircuit(t, circuitMap, circuit4)
assertHasKeystone(t, circuitMap, keystone4.OutKey, circuit4)
// Verify that each circuit is exposed via the proper hash bucketing.
assertNumCircuitsWithHash(t, circuitMap, hash1, 2)
assertHasCircuitForHash(t, circuitMap, hash1, circuit1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertHasCircuitForHash(t, circuitMap, hash2, circuit2)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
assertHasCircuitForHash(t, circuitMap, hash3, circuit3)
// Test removing circuits and the subsequent lookups.
err := circuitMap.Remove(chan1, 0)
err = circuitMap.DeleteCircuits(circuit1.Incoming)
if err != nil {
t.Fatalf("Remove returned unexpected error: %v", err)
}
circuits = circuitMap.LookupByPaymentHash(hash1)
if len(circuits) != 1 {
t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+
"hash1: expecected %d, got %d", 1, len(circuits))
}
if circuits[0].OutgoingHTLCID != 3 {
t.Fatalf("LookupByPaymentHash returned wrong circuit for hash1: %v",
circuits[0])
}
// There should be exactly one remaining circuit with hash1, and it
// should be circuit4.
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
// Removing already-removed circuit should return an error.
err = circuitMap.Remove(chan1, 0)
err = circuitMap.DeleteCircuits(circuit1.Incoming)
if err == nil {
t.Fatal("Remove did not return expected not found error")
}
// Verify that nothing related to hash1 has changed
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
// Remove last remaining circuit with payment hash hash1.
err = circuitMap.Remove(chan1, 3)
err = circuitMap.DeleteCircuits(circuit4.Incoming)
if err != nil {
t.Fatalf("Remove returned unexpected error: %v", err)
}
circuits = circuitMap.LookupByPaymentHash(hash1)
if len(circuits) != 0 {
t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+
"hash1: expecected %d, got %d", 0, len(circuits))
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
// Remove last remaining circuit with payment hash hash2.
err = circuitMap.DeleteCircuits(circuit2.Incoming)
if err != nil {
t.Fatalf("Remove returned unexpected error: %v", err)
}
// There should now only be one remaining circuit, with hash3.
assertNumCircuitsWithHash(t, circuitMap, hash2, 0)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash2, 0)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
// Remove last remaining circuit with payment hash hash3.
err = circuitMap.DeleteCircuits(circuit3.Incoming)
if err != nil {
t.Fatalf("Remove returned unexpected error: %v", err)
}
// Check that the circuit map is empty, even after restarting.
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
cdb, circuitMap = restartCircuitMap(t, cdb)
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
}
// assertHasKeystone tests that the circuit map contains the provided payment
// circuit.
func assertHasKeystone(t *testing.T, cm htlcswitch.CircuitMap,
outKey htlcswitch.CircuitKey, c *htlcswitch.PaymentCircuit) {
circuit := cm.LookupOpenCircuit(outKey)
if !equalIgnoreLFD(circuit, c) {
t.Fatalf("unexpected circuit, want: %v, got %v", c, circuit)
}
}
// assertDoesNotHaveKeystone tests that the circuit map does not contain a
// circuit for the provided outgoing circuit key.
func assertDoesNotHaveKeystone(t *testing.T, cm htlcswitch.CircuitMap,
outKey htlcswitch.CircuitKey) {
circuit := cm.LookupOpenCircuit(outKey)
if circuit != nil {
t.Fatalf("expected no circuit for keystone %s, found %v",
outKey, circuit)
}
}
// assertHasCircuitForHash tests that the provided circuit appears in the list
// of circuits for the given hash.
func assertHasCircuitForHash(t *testing.T, cm htlcswitch.CircuitMap, hash [32]byte,
circuit *htlcswitch.PaymentCircuit) {
circuits := cm.LookupByPaymentHash(hash)
for _, c := range circuits {
if equalIgnoreLFD(c, circuit) {
return
}
}
t.Fatalf("unable to find circuit: %v by hash: %v", circuit, hash)
}
// assertNumCircuitsWithHash tests that the circuit has the right number of full
// circuits, indexed by the given hash.
func assertNumCircuitsWithHash(t *testing.T, cm htlcswitch.CircuitMap,
hash [32]byte, expectedNum int) {
circuits := cm.LookupByPaymentHash(hash)
if len(circuits) != expectedNum {
t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+
"hash=%v: expecected %d, got %d", hash, expectedNum,
len(circuits))
}
}
// assertHasCircuit queries the circuit map using the half-circuit's half
// key, and fails if the returned half-circuit differs from the provided one.
func assertHasCircuit(t *testing.T, cm htlcswitch.CircuitMap,
c *htlcswitch.PaymentCircuit) {
c2 := cm.LookupCircuit(c.Incoming)
if !equalIgnoreLFD(c, c2) {
t.Fatalf("expected circuit: %v, got %v", c, c2)
}
}
// equalIgnoreLFD compares two payment circuits, but ignores the current value
// of LoadedFromDisk. The value is temporarily set to false for the comparison
// and then restored.
func equalIgnoreLFD(c, c2 *htlcswitch.PaymentCircuit) bool {
ogLFD := c.LoadedFromDisk
ogLFD2 := c2.LoadedFromDisk
c.LoadedFromDisk = false
c2.LoadedFromDisk = false
isEqual := reflect.DeepEqual(c, c2)
c.LoadedFromDisk = ogLFD
c2.LoadedFromDisk = ogLFD2
return isEqual
}
// assertDoesNotHaveCircuit queries the circuit map using the circuit's
// incoming circuit key, and fails if it is found.
func assertDoesNotHaveCircuit(t *testing.T, cm htlcswitch.CircuitMap,
c *htlcswitch.PaymentCircuit) {
c2 := cm.LookupCircuit(c.Incoming)
if c2 != nil {
t.Fatalf("expected no circuit for %v, got %v", c, c2)
}
}
// makeCircuitDB initializes a new test channeldb for testing the persistence of
// the circuit map. If an empty string is provided as a path, a temp directory
// will be created.
func makeCircuitDB(t *testing.T, path string) *channeldb.DB {
if path == "" {
var err error
path, err = ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to create temp path: %v", err)
}
}
db, err := channeldb.Open(path)
if err != nil {
t.Fatalf("unable to open channel db: %v", err)
}
return db
}
// Creates a new circuit map, backed by a freshly opened channeldb. The existing
// channeldb is closed in order to simulate a complete restart.
func restartCircuitMap(t *testing.T, cdb *channeldb.DB) (*channeldb.DB,
htlcswitch.CircuitMap) {
// Record the current temp path and close current db.
dbPath := cdb.Path()
cdb.Close()
// Reinitialize circuit map with same db path.
cdb2 := makeCircuitDB(t, dbPath)
cm2, err := htlcswitch.NewCircuitMap(cdb2)
if err != nil {
t.Fatalf("unable to recreate persistent circuit map: %v", err)
}
return cdb2, cm2
}
// TestCircuitMapCommitCircuits tests the following behavior of CommitCircuits:
// 1. New circuits are successfully added.
// 2. Duplicate circuits are dropped anytime before circuit map shutsdown.
// 3. Duplicate circuits are failed anytime after circuit map restarts.
func TestCircuitMapCommitCircuits(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
actions, err := circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
if len(actions.Drops) > 0 {
t.Fatalf("new circuit should not have been dropped")
}
if len(actions.Fails) > 0 {
t.Fatalf("new circuit should not have failed")
}
if len(actions.Adds) != 1 {
t.Fatalf("only one circuit should have been added, found %d",
len(actions.Adds))
}
circuit2 := circuitMap.LookupCircuit(circuit.Incoming)
if !reflect.DeepEqual(circuit, circuit2) {
t.Fatalf("unexpected committed circuit: got %v, want %v",
circuit2, circuit)
}
// Then we will try to readd the same circuit again, this should result
// in the circuit being dropped. This can happen if the incoming link
// flaps.
actions, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
if len(actions.Adds) > 0 {
t.Fatalf("duplicate circuit should not have been added to circuit map")
}
if len(actions.Fails) > 0 {
t.Fatalf("duplicate circuit should not have failed")
}
if len(actions.Drops) != 1 {
t.Fatalf("only one circuit should have been dropped, found %d",
len(actions.Drops))
}
// Finally, restart the circuit map, which will cause the added circuit
// to be loaded from disk. Since the keystone was never set, subsequent
// attempts to commit the circuit should cause the circuit map to
// indicate that that the HTLC should be failed back.
cdb, circuitMap = restartCircuitMap(t, cdb)
actions, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
if len(actions.Adds) > 0 {
t.Fatalf("duplicate circuit with incomplete forwarding " +
"decision should not have been added to circuit map")
}
if len(actions.Drops) > 0 {
t.Fatalf("duplicate circuit with incomplete forwarding " +
"decision should not have been dropped by circuit map")
}
if len(actions.Fails) != 1 {
t.Fatalf("only one duplicate circuit with incomplete "+
"forwarding decision should have been failed, found: "+
"%d", len(actions.Fails))
}
// Lookup the committed circuit again, it should be identical apart from
// the loaded from disk flag.
circuit2 = circuitMap.LookupCircuit(circuit.Incoming)
if !equalIgnoreLFD(circuit, circuit2) {
t.Fatalf("unexpected committed circuit: got %v, want %v",
circuit2, circuit)
}
}
// TestCircuitMapOpenCircuits checks that circuits are properly opened, and that
// duplicate attempts to open a circuit will result in an error.
func TestCircuitMapOpenCircuits(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
chan2 = lnwire.NewShortChanIDFromInt(2)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
_, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
keystone := htlcswitch.Keystone{
InKey: circuit.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 2,
},
}
// Open the circuit for the first time.
err = circuitMap.OpenCircuits(keystone)
if err != nil {
t.Fatalf("failed to open circuits: %v", err)
}
// Check that we can retrieve the open circuit if the circuit map before
// the circuit map is restarted.
circuit2 := circuitMap.LookupOpenCircuit(keystone.OutKey)
if !reflect.DeepEqual(circuit, circuit2) {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, circuit)
}
if !circuit2.HasKeystone() {
t.Fatalf("open circuit should have keystone")
}
if !reflect.DeepEqual(&keystone.OutKey, circuit2.Outgoing) {
t.Fatalf("expected open circuit to have outgoing key: %v, found %v",
&keystone.OutKey, circuit2.Outgoing)
}
// Open the circuit for a second time, which should fail due to a
// duplicate keystone
err = circuitMap.OpenCircuits(keystone)
if err != htlcswitch.ErrDuplicateKeystone {
t.Fatalf("failed to open circuits: %v", err)
}
// Then we will try to readd the same circuit again, this should result
// in the circuit being dropped. This can happen if the incoming link
// flaps OR the switch is entirely restarted and the outgoing link has
// not received a response.
actions, err := circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
if len(actions.Adds) > 0 {
t.Fatalf("duplicate circuit should not have been added to circuit map")
}
if len(actions.Fails) > 0 {
t.Fatalf("duplicate circuit should not have failed")
}
if len(actions.Drops) != 1 {
t.Fatalf("only one circuit should have been dropped, found %d",
len(actions.Drops))
}
// Now, restart the circuit map, which will cause the opened circuit to
// be loaded from disk. Since we set the keystone on this circuit, it
// should be restored as such in memory.
//
// NOTE: The channel db doesn't have any channel data, so no keystones
// will be trimmed.
cdb, circuitMap = restartCircuitMap(t, cdb)
// Check that we can still query for the open circuit.
circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey)
if !equalIgnoreLFD(circuit, circuit2) {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, circuit)
}
// Try to open the circuit again, we expect this to fail since the open
// circuit was restored.
err = circuitMap.OpenCircuits(keystone)
if err != htlcswitch.ErrDuplicateKeystone {
t.Fatalf("failed to open circuits: %v", err)
}
// Lastly, with the circuit map restarted, try one more time to recommit
// the open circuit. This should be dropped, and is expected to happen
// if the incoming link flaps OR the switch is entirely restarted and
// the outgoing link has not received a response.
actions, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
if len(actions.Adds) > 0 {
t.Fatalf("duplicate circuit should not have been added to circuit map")
}
if len(actions.Fails) > 0 {
t.Fatalf("duplicate circuit should not have failed")
}
if len(actions.Drops) != 1 {
t.Fatalf("only one circuit should have been dropped, found %d",
len(actions.Drops))
}
}
func assertCircuitsOpenedPreRestart(t *testing.T,
circuitMap htlcswitch.CircuitMap,
circuits []*htlcswitch.PaymentCircuit,
keystones []htlcswitch.Keystone) {
for i, circuit := range circuits {
keystone := keystones[i]
openCircuit := circuitMap.LookupOpenCircuit(keystone.OutKey)
if !reflect.DeepEqual(circuit, openCircuit) {
t.Fatalf("unexpected open circuit %d: got %v, want %v",
i, openCircuit, circuit)
}
if !openCircuit.HasKeystone() {
t.Fatalf("open circuit %d should have keystone", i)
}
if !reflect.DeepEqual(&keystone.OutKey, openCircuit.Outgoing) {
t.Fatalf("expected open circuit %d to have outgoing "+
"key: %v, found %v", i,
&keystone.OutKey, openCircuit.Outgoing)
}
}
}
func assertCircuitsOpenedPostRestart(t *testing.T,
circuitMap htlcswitch.CircuitMap,
circuits []*htlcswitch.PaymentCircuit,
keystones []htlcswitch.Keystone) {
for i, circuit := range circuits {
keystone := keystones[i]
openCircuit := circuitMap.LookupOpenCircuit(keystone.OutKey)
if !equalIgnoreLFD(circuit, openCircuit) {
t.Fatalf("unexpected open circuit %d: got %v, want %v",
i, openCircuit, circuit)
}
if !openCircuit.HasKeystone() {
t.Fatalf("open circuit %d should have keystone", i)
}
if !reflect.DeepEqual(&keystone.OutKey, openCircuit.Outgoing) {
t.Fatalf("expected open circuit %d to have outgoing "+
"key: %v, found %v", i,
&keystone.OutKey, openCircuit.Outgoing)
}
}
}
func assertCircuitsNotOpenedPreRestart(t *testing.T,
circuitMap htlcswitch.CircuitMap,
circuits []*htlcswitch.PaymentCircuit,
keystones []htlcswitch.Keystone,
offset int) {
for i := range circuits {
keystone := keystones[i]
openCircuit := circuitMap.LookupOpenCircuit(keystone.OutKey)
if openCircuit != nil {
t.Fatalf("expected circuit %d not to be open",
offset+i)
}
circuit := circuitMap.LookupCircuit(keystone.InKey)
if circuit == nil {
t.Fatalf("expected to find unopened circuit %d",
offset+i)
}
if circuit.HasKeystone() {
t.Fatalf("circuit %d should not have keystone",
offset+i)
}
}
}
// TestCircuitMapTrimOpenCircuits verifies that the circuit map properly removes
// circuits from disk and the in-memory state when TrimOpenCircuits is used.
// This test checks that a successful trim survives a restart, and that circuits
// added before the restart can also be trimmed.
func TestCircuitMapTrimOpenCircuits(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
chan2 = lnwire.NewShortChanIDFromInt(2)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
const nCircuits = 10
const firstTrimIndex = 7
const secondTrimIndex = 3
// Create a list of all circuits that will be committed in the circuit
// map. The incoming HtlcIDs are chosen so that there is overlap with
// the outgoing HtlcIDs, but ensures that the test is not dependent on
// them being equal.
circuits := make([]*htlcswitch.PaymentCircuit, nCircuits)
for i := range circuits {
circuits[i] = &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: uint64(i + 3),
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
_, err = circuitMap.CommitCircuits(circuits...)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
// Now create a list of the keystones that we will use to preemptively
// open the circuits. We set the index as the outgoing HtlcID to i
// simplify the indexing logic of the test.
keystones := make([]htlcswitch.Keystone, nCircuits)
for i := range keystones {
keystones[i] = htlcswitch.Keystone{
InKey: circuits[i].Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: uint64(i),
},
}
}
// Open the circuits for the first time.
err = circuitMap.OpenCircuits(keystones...)
if err != nil {
t.Fatalf("failed to open circuits: %v", err)
}
// Check that all circuits are marked open.
assertCircuitsOpenedPreRestart(t, circuitMap, circuits, keystones)
// Now trim up above outgoing htlcid `firstTrimIndex` (7). This should
// leave the first 7 circuits open, and the rest should be reverted to
// an unopened state.
err = circuitMap.TrimOpenCircuits(chan2, firstTrimIndex)
if err != nil {
t.Fatalf("unable to trim circuits")
}
assertCircuitsOpenedPreRestart(t,
circuitMap,
circuits[:firstTrimIndex],
keystones[:firstTrimIndex],
)
assertCircuitsNotOpenedPreRestart(
t,
circuitMap,
circuits[firstTrimIndex:],
keystones[firstTrimIndex:],
firstTrimIndex,
)
// Restart the circuit map, verify that that the trim is reflected on
// startup.
cdb, circuitMap = restartCircuitMap(t, cdb)
assertCircuitsOpenedPostRestart(
t,
circuitMap,
circuits[:firstTrimIndex],
keystones[:firstTrimIndex],
)
assertCircuitsNotOpenedPreRestart(
t,
circuitMap,
circuits[firstTrimIndex:],
keystones[firstTrimIndex:],
firstTrimIndex,
)
// Now, trim above outgoing htlcid `secondTrimIndex` (3). Only the first
// three circuits should be open, with any others being reverted back to
// unopened.
err = circuitMap.TrimOpenCircuits(chan2, secondTrimIndex)
if err != nil {
t.Fatalf("unable to trim circuits")
}
assertCircuitsOpenedPostRestart(
t,
circuitMap,
circuits[:secondTrimIndex],
keystones[:secondTrimIndex],
)
assertCircuitsNotOpenedPreRestart(
t,
circuitMap,
circuits[secondTrimIndex:],
keystones[secondTrimIndex:],
secondTrimIndex,
)
// Restart the circuit map one last time to make sure the changes are
// persisted.
cdb, circuitMap = restartCircuitMap(t, cdb)
assertCircuitsOpenedPostRestart(
t,
circuitMap,
circuits[:secondTrimIndex],
keystones[:secondTrimIndex],
)
assertCircuitsNotOpenedPreRestart(
t,
circuitMap,
circuits[secondTrimIndex:],
keystones[secondTrimIndex:],
secondTrimIndex,
)
}
// TestCircuitMapCloseOpenCircuits asserts that the circuit map can properly
// close open circuits, and that it allows at most one response to do so
// successfully. It also checks that a circuit is reopened if the close was not
// persisted via DeleteCircuits, and can again be closed.
func TestCircuitMapCloseOpenCircuits(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
chan2 = lnwire.NewShortChanIDFromInt(2)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
_, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
keystone := htlcswitch.Keystone{
InKey: circuit.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 2,
},
}
// Open the circuit for the first time.
err = circuitMap.OpenCircuits(keystone)
if err != nil {
t.Fatalf("failed to open circuits: %v", err)
}
// Check that we can retrieve the open circuit if the circuit map before
// the circuit map is restarted.
circuit2 := circuitMap.LookupOpenCircuit(keystone.OutKey)
if !reflect.DeepEqual(circuit, circuit2) {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, circuit)
}
// Open the circuit for a second time, which should fail due to a
// duplicate keystone
err = circuitMap.OpenCircuits(keystone)
if err != htlcswitch.ErrDuplicateKeystone {
t.Fatalf("failed to open circuits: %v", err)
}
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Closing the circuit a second time should result in a failure.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != htlcswitch.ErrCircuitClosing {
t.Fatalf("unable to close unopened circuit")
}
// Now, restart the circuit map, which will cause the opened circuit to
// be loaded from disk. Since we set the keystone on this circuit, it
// should be restored as such in memory.
//
// NOTE: The channel db doesn't have any channel data, so no keystones
// will be trimmed.
cdb, circuitMap = restartCircuitMap(t, cdb)
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Closing the circuit a second time should result in a failure.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != htlcswitch.ErrCircuitClosing {
t.Fatalf("unable to close unopened circuit")
}
}
// TestCircuitMapCloseUnopenedCircuit tests that closing an unopened circuit
// allows at most semantics, and that the close is not persisted across
// restarts.
func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
_, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Closing the circuit a second time should result in a failure.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != htlcswitch.ErrCircuitClosing {
t.Fatalf("unable to close unopened circuit")
}
// Now, restart the circuit map, which will result in the circuit being
// reopened, since no attempt to delete the circuit was made.
cdb, circuitMap = restartCircuitMap(t, cdb)
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Closing the circuit a second time should result in a failure.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != htlcswitch.ErrCircuitClosing {
t.Fatalf("unable to close unopened circuit")
}
}
// TestCircuitMapDeleteUnopenedCircuit checks that an unopened circuit can be
// removed persistently from the circuit map.
func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
_, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
err = circuitMap.DeleteCircuits(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Check that we can retrieve the open circuit if the circuit map before
// the circuit map is restarted.
circuit2 := circuitMap.LookupCircuit(circuit.Incoming)
if circuit2 != nil {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, nil)
}
// Now, restart the circuit map, and check that the deletion survived
// the restart.
cdb, circuitMap = restartCircuitMap(t, cdb)
circuit2 = circuitMap.LookupCircuit(circuit.Incoming)
if circuit2 != nil {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, nil)
}
}
// TestCircuitMapDeleteUnopenedCircuit checks that an open circuit can be
// removed persistently from the circuit map.
func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
t.Parallel()
var (
chan1 = lnwire.NewShortChanIDFromInt(1)
chan2 = lnwire.NewShortChanIDFromInt(2)
circuitMap htlcswitch.CircuitMap
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
}
// First we will try to add an new circuit to the circuit map, this
// should succeed.
_, err = circuitMap.CommitCircuits(circuit)
if err != nil {
t.Fatalf("failed to commit circuits: %v", err)
}
keystone := htlcswitch.Keystone{
InKey: circuit.Incoming,
OutKey: htlcswitch.CircuitKey{
ChanID: chan2,
HtlcID: 2,
},
}
// Open the circuit for the first time.
err = circuitMap.OpenCircuits(keystone)
if err != nil {
t.Fatalf("failed to open circuits: %v", err)
}
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Persistently remove the circuit identified by incoming chan id.
err = circuitMap.DeleteCircuits(circuit.Incoming)
if err != nil {
t.Fatalf("unable to close unopened circuit")
}
// Check that we can no longer retrieve the open circuit.
circuit2 := circuitMap.LookupOpenCircuit(keystone.OutKey)
if circuit2 != nil {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, nil)
}
// Now, restart the circuit map, and check that the deletion survived
// the restart.
cdb, circuitMap = restartCircuitMap(t, cdb)
circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey)
if circuit2 != nil {
t.Fatalf("unexpected open circuit: got %v, want %v",
circuit2, nil)
}
}

@ -47,7 +47,7 @@ type ChannelLink interface {
//
// NOTE: This function MUST be non-blocking (or block as little as
// possible).
HandleSwitchPacket(*htlcPacket)
HandleSwitchPacket(*htlcPacket) error
// HandleChannelUpdate handles the htlc requests as settle/add/fail
// which sent to us from remote peer we have a channel with.
@ -98,6 +98,10 @@ type ChannelLink interface {
// will use this function in forwarding decisions accordingly.
EligibleToForward() bool
// AttachMailBox delivers an active MailBox to the link. The MailBox may
// have buffered messages.
AttachMailBox(MailBox)
// Start/Stop are used to initiate the start/stop of the channel link
// functioning.
Start() error

@ -31,6 +31,10 @@ const (
expiryGraceDelta = 2
)
// ErrInternalLinkFailure is a generic error returned to the remote party so as
// to obfuscate the true failure.
var ErrInternalLinkFailure = errors.New("internal link failure")
// ForwardingPolicy describes the set of constraints that a given ChannelLink
// is to adhere to when forwarding HTLC's. For each incoming HTLC, this set of
// constraints will be consulted in order to ensure that adequate fees are
@ -119,13 +123,21 @@ type ChannelLinkConfig struct {
// targeted at a given ChannelLink concrete interface implementation.
FwrdingPolicy ForwardingPolicy
// Switch is a subsystem which is used to forward the incoming HTLC
// packets according to the encoded hop forwarding information
// contained in the forwarding blob within each HTLC.
//
// TODO(roasbeef): remove in favor of simple ForwardPacket closure func
// Circuits provides restricted access to the switch's circuit map,
// allowing the link to open and close circuits.
Circuits CircuitModifier
// Switch provides a reference to the HTLC switch, we only use this in
// testing to access circuit operations not typically exposed by the
// CircuitModifier.
// TODO(conner): remove after refactoring htlcswitch testing framework.
Switch *Switch
// ForwardPackets attempts to forward the batch of htlcs through the
// switch. Any failed packets will be returned to the provided
// ChannelLink.
ForwardPackets func(...*htlcPacket) chan error
// DecodeHopIterator function is responsible for decoding HTLC Sphinx
// onion blob, and creating hop iterator which will give us next
// destination of HTLC.
@ -209,9 +221,20 @@ type ChannelLinkConfig struct {
// coalesced into a single commit.
BatchTicker Ticker
// FwdPkgGCTicker is the ticker determining the frequency at which
// garbage collection of forwarding packages occurs. We use a time-based
// approach, as opposed to block epochs, as to not hinder syncing.
FwdPkgGCTicker Ticker
// BatchSize is the max size of a batch of updates done to the link
// before we do a state update.
BatchSize uint32
// UnsafeReplay will cause a link to replay the adds in its latest
// commitment txn after the link is restarted. This should only be used
// in testing, it is here to ensure the sphinx replay detection on the
// receiving node is persistent.
UnsafeReplay bool
}
// channelLink is the service which drives a channel's commitment update
@ -237,6 +260,13 @@ type channelLink struct {
// use this information to govern decisions based on HTLC timeouts.
bestHeight uint32
// keystoneBatch represents a volatile list of keystones that must be
// written before attempting to sign the next commitment txn.
keystoneBatch []Keystone
openedCircuits []CircuitKey
closedCircuits []CircuitKey
// channel is a lightning network channel to which we apply htlc
// updates.
channel *lnwallet.LightningChannel
@ -252,11 +282,15 @@ type channelLink struct {
// been processed because of the commitment transaction overflow.
overflowQueue *packetQueue
// startMailBox directs whether or not to start the mailbox when
// starting the link. It may have already been started by the switch.
startMailBox bool
// mailBox is the main interface between the outside world and the
// link. All incoming messages will be sent over this mailBox. Messages
// include new updates from our connected peer, and new packets to be
// forwarded sent by the switch.
mailBox *memoryMailBox
mailBox MailBox
// upstream is a channel that new messages sent from the remote peer to
// the local peer will be sent across.
@ -295,11 +329,10 @@ type channelLink struct {
func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel,
currentHeight uint32) ChannelLink {
link := &channelLink{
return &channelLink{
cfg: cfg,
channel: channel,
shortChanID: channel.ShortChanID(),
mailBox: newMemoryMailBox(),
linkControl: make(chan interface{}),
// TODO(roasbeef): just do reserve here?
logCommitTimer: time.NewTimer(300 * time.Millisecond),
@ -308,11 +341,6 @@ func NewChannelLink(cfg ChannelLinkConfig, channel *lnwallet.LightningChannel,
htlcUpdates: make(chan []channeldb.HTLC),
quit: make(chan struct{}),
}
link.upstream = link.mailBox.MessageOutBox()
link.downstream = link.mailBox.PacketOutBox()
return link
}
// A compile time check to ensure channelLink implements the ChannelLink
@ -347,7 +375,7 @@ func (l *channelLink) Start() error {
}
}()
l.mailBox.Start()
l.mailBox.ResetMessages()
l.overflowQueue.Start()
l.wg.Add(1)
@ -374,7 +402,6 @@ func (l *channelLink) Stop() {
l.channel.Stop()
l.mailBox.Stop()
l.overflowQueue.Stop()
close(l.quit)
@ -500,10 +527,16 @@ func (l *channelLink) syncChanStates() error {
log.Infof("Received re-establishment message from remote side "+
"for channel(%v)", l.channel.ChannelPoint())
var (
openedCircuits []CircuitKey
closedCircuits []CircuitKey
)
// We've just received a ChnSync message from the remote party,
// so we'll process the message in order to determine if we
// need to re-transmit any messages to the remote party.
msgsToReSend, _, _, err = l.channel.ProcessChanSyncMsg(remoteChanSyncMsg)
msgsToReSend, openedCircuits, closedCircuits, err =
l.channel.ProcessChanSyncMsg(remoteChanSyncMsg)
if err != nil {
// TODO(roasbeef): check concrete type of error, act
// accordingly
@ -511,6 +544,17 @@ func (l *channelLink) syncChanStates() error {
"message: %v", err)
}
// Repopulate any identifiers for circuits that may have been
// opened or unclosed.
l.openedCircuits = openedCircuits
l.closedCircuits = closedCircuits
// Ensure that all packets have been have been removed from the
// link's mailbox.
if err := l.ackDownStreamPackets(true); err != nil {
return err
}
if len(msgsToReSend) > 0 {
log.Infof("Sending %v updates to synchronize the "+
"state for ChannelPoint(%v)", len(msgsToReSend),
@ -532,79 +576,128 @@ func (l *channelLink) syncChanStates() error {
"deadline")
}
// In order to prep for the fragment below, we'll note if we
// retransmitted any HTLC's settles earlier. We'll track them by the
// HTLC index of the remote party in order to avoid erroneously sending
// a duplicate settle.
htlcsSettled := make(map[uint64]struct{})
for _, msg := range msgsToReSend {
settleMsg, ok := msg.(*lnwire.UpdateFulfillHTLC)
if !ok {
// If this isn't a settle message, then we'll skip it.
continue
}
return nil
}
// Otherwise, we'll note the ID of the HTLC we're settling so we
// don't duplicate it below.
htlcsSettled[settleMsg.ID] = struct{}{}
// resolveFwdPkgs loads any forwarding packages for this link from disk, and
// reprocesses them in order. The primary goal is to make sure that any HTLCs we
// previously received are reinstated in memory, and forwarded to the switch if
// necessary. After a restart, this will also delete any previously completed
// packages.
func (l *channelLink) resolveFwdPkgs() error {
fwdPkgs, err := l.channel.LoadFwdPkgs()
if err != nil {
return err
}
// Now that we've synchronized our state, we'll check to see if
// there're any HTLC's that we received, but weren't able to settle
// directly the last time we were active. If we find any, then we'll
// send the settle message, then being to initiate a state transition.
//
// TODO(roasbeef): can later just inspect forwarding package
activeHTLCs := l.channel.ActiveHtlcs()
for _, htlc := range activeHTLCs {
if !htlc.Incoming {
continue
}
l.debugf("loaded %d fwd pks", len(fwdPkgs))
// Before we attempt to settle this HTLC, we'll check to see if
// we just re-sent it as part of the channel sync. If so, then
// we'll skip it.
if _, ok := htlcsSettled[htlc.HtlcIndex]; ok {
continue
}
// Now we'll check to if we we actually know the preimage if we
// don't then we'll skip it.
preimage, ok := l.cfg.PreimageCache.LookupPreimage(htlc.RHash[:])
if !ok {
continue
}
// At this point, we've found an unsettled HTLC that we know
// the preimage to, so we'll send a settle message to the
// remote party.
var p [32]byte
copy(p[:], preimage)
err := l.channel.SettleHTLC(p, htlc.HtlcIndex, nil, nil, nil)
var needUpdate bool
for _, fwdPkg := range fwdPkgs {
hasUpdate, err := l.resolveFwdPkg(fwdPkg)
if err != nil {
l.fail("unable to settle htlc: %v", err)
return err
}
// We'll now mark the HTLC as settled in the invoice database,
// then send the settle message to the remote party.
err = l.cfg.Registry.SettleInvoice(htlc.RHash)
if err != nil {
l.fail("unable to settle invoice: %v", err)
return err
}
l.batchCounter++
l.cfg.Peer.SendMessage(&lnwire.UpdateFulfillHTLC{
ChanID: l.ChanID(),
ID: htlc.HtlcIndex,
PaymentPreimage: p,
})
needUpdate = needUpdate || hasUpdate
}
// If any of our reprocessing steps require an update to the commitment
// txn, we initiate a state transition to capture all relevant changes.
if needUpdate {
return l.updateCommitTx()
}
return nil
}
// resolveFwdPkg interprets the FwdState of the provided package, either
// reprocesses any outstanding htlcs in the package, or performs garbage
// collection on the package.
func (l *channelLink) resolveFwdPkg(fwdPkg *channeldb.FwdPkg) (bool, error) {
// Remove any completed packages to clear up space.
if fwdPkg.State == channeldb.FwdStateCompleted {
l.debugf("removing completed fwd pkg for height=%d",
fwdPkg.Height)
err := l.channel.RemoveFwdPkg(fwdPkg.Height)
if err != nil {
l.errorf("unable to remove fwd pkg for height=%d: %v",
fwdPkg.Height, err)
return false, err
}
}
// Otherwise this is either a new package or one has gone through
// processing, but contains htlcs that need to be restored in memory. We
// replay this forwarding package to make sure our local mem state is
// resurrected, we mimic any original responses back to the remote
// party, and reforward the relevant HTLCs to the switch.
// If the package is fully acked but not completed, it must still have
// settles and fails to propagate.
if !fwdPkg.SettleFailFilter.IsFull() {
settleFails := lnwallet.PayDescsFromRemoteLogUpdates(
fwdPkg.Source, fwdPkg.Height, fwdPkg.SettleFails,
)
l.processRemoteSettleFails(fwdPkg, settleFails)
}
// Finally, replay *ALL ADDS* in this forwarding package. The downstream
// logic is able to filter out any duplicates, but we must shove the
// entire, original set of adds down the pipeline so that the batch of
// adds presented to the sphinx router does not ever change.
var needUpdate bool
if !fwdPkg.AckFilter.IsFull() {
adds := lnwallet.PayDescsFromRemoteLogUpdates(
fwdPkg.Source, fwdPkg.Height, fwdPkg.Adds,
)
needUpdate = l.processRemoteAdds(fwdPkg, adds)
}
return needUpdate, nil
}
// fwdPkgGarbager periodically reads all forwarding packages from disk and
// removes those that can be discarded. It is safe to do this entirely in the
// background, since all state is coordinated on disk. This also ensures the
// link can continue to process messages and interleave database accesses.
//
// NOTE: This MUST be run as a goroutine.
func (l *channelLink) fwdPkgGarbager() {
defer l.wg.Done()
fwdPkgGcTick := l.cfg.FwdPkgGCTicker.Start()
defer l.cfg.FwdPkgGCTicker.Stop()
for {
select {
case <-fwdPkgGcTick:
fwdPkgs, err := l.channel.LoadFwdPkgs()
if err != nil {
l.warnf("unable to load fwdpkgs for gc: %v", err)
continue
}
// TODO(conner): batch removal of forward packages.
for _, fwdPkg := range fwdPkgs {
if fwdPkg.State != channeldb.FwdStateCompleted {
continue
}
err = l.channel.RemoveFwdPkg(fwdPkg.Height)
if err != nil {
l.warnf("unable to remove fwd pkg "+
"for height=%d: %v",
fwdPkg.Height, err)
}
}
case <-l.quit:
return
}
}
}
// htlcManager is the primary goroutine which drives a channel's commitment
// update state-machine in response to messages received via several channels.
// This goroutine reads messages from the upstream (remote) peer, and also from
@ -625,6 +718,24 @@ func (l *channelLink) htlcManager() {
log.Infof("HTLC manager for ChannelPoint(%v) started, "+
"bandwidth=%v", l.channel.ChannelPoint(), l.Bandwidth())
// Before handling any messages, revert any circuits that were marked
// open in the switch's circuit map, but did not make it into a
// commitment txn. We use the next local htlc index as the cut off
// point, since all indexes below that are committed.
//
// NOTE: This is automatically done by the switch when it starts up, but
// is necessary to prevent inconsistencies in the case that the link
// flaps. This is a result of a link's life-cycle being shorter than
// that of the switch.
localHtlcIndex := l.channel.LocalHtlcIndex()
err := l.cfg.Circuits.TrimOpenCircuits(l.ShortChanID(), localHtlcIndex)
if err != nil {
l.errorf("unable to trim circuits above local htlc index %d: %v",
localHtlcIndex, err)
l.fail(ErrInternalLinkFailure.Error())
return
}
// TODO(roasbeef): need to call wipe chan whenever D/C?
// If this isn't the first time that this channel link has been
@ -634,11 +745,34 @@ func (l *channelLink) htlcManager() {
if l.cfg.SyncStates {
// TODO(roasbeef): need to ensure haven't already settled?
if err := l.syncChanStates(); err != nil {
l.errorf("unable to synchronize channel states: %v", err)
l.fail(err.Error())
return
}
}
// With the channel states synced, we now reset the mailbox to ensure we
// start processing all unacked packets in order. This is done here to
// ensure that all acknowledgments that occur during channel
// resynchronization have taken affect, causing us only to pull unacked
// packets after starting to read from the downstream mailbox.
l.mailBox.ResetPackets()
// After cleaning up any memory pertaining to incoming packets, we now
// replay our forwarding packages to handle any htlcs that can be
// processed locally, or need to be forwarded out to the switch.
if err := l.resolveFwdPkgs(); err != nil {
l.errorf("unable to resolve fwd pkgs: %v", err)
l.fail(ErrInternalLinkFailure.Error())
return
}
// With our link's in-memory state fully reconstructed, spawn a
// goroutine to manage the reclamation of disk space occupied by
// completed forwarding packages.
l.wg.Add(1)
go l.fwdPkgGarbager()
batchTick := l.cfg.BatchTicker.Start()
defer l.cfg.BatchTicker.Stop()
@ -815,11 +949,13 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
var isSettle bool
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
// A new payment has been initiated via the downstream channel,
// so we add the new HTLC to our local log, then update the
// commitment chains.
htlc.ChanID = l.ChanID()
index, err := l.channel.AddHTLC(htlc, nil)
openCircuitRef := pkt.inKey()
index, err := l.channel.AddHTLC(htlc, &openCircuitRef)
if err != nil {
switch err {
@ -871,17 +1007,28 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
failPkt := &htlcPacket{
incomingChanID: pkt.incomingChanID,
incomingHTLCID: pkt.incomingHTLCID,
amount: htlc.Amount,
isRouted: true,
circuit: pkt.circuit,
sourceRef: pkt.sourceRef,
hasSource: true,
localFailure: localFailure,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
}
// TODO(roasbeef): need to identify if sent
// from switch so don't need to obfuscate
go l.cfg.Switch.forward(failPkt)
go l.forwardBatch(failPkt)
// Remove this packet from the link's mailbox,
// this prevents it from being reprocessed if
// the link restarts and resets it mailbox. If
// this response doesn't make it back to the
// originating link, it will be rejected upon
// attempting to reforward the Add to the
// switch, since the circuit was never fully
// opened, and the forwarding package shows it
// as unacknowledged.
l.mailBox.AckPacket(pkt.inKey())
return
}
}
@ -890,39 +1037,41 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
"local_log_index=%v, batch_size=%v",
htlc.PaymentHash[:], index, l.batchCounter+1)
// Create circuit (remember the path) in order to forward
// settle/fail packet back.
l.cfg.Switch.addCircuit(&PaymentCircuit{
PaymentHash: htlc.PaymentHash,
IncomingChanID: pkt.incomingChanID,
IncomingHTLCID: pkt.incomingHTLCID,
IncomingAmt: pkt.incomingHtlcAmt,
OutgoingChanID: l.ShortChanID(),
OutgoingHTLCID: index,
OutgoingAmt: htlc.Amount,
ErrorEncrypter: pkt.obfuscator,
})
pkt.outgoingChanID = l.ShortChanID()
pkt.outgoingHTLCID = index
htlc.ID = index
l.debugf("Queueing keystone of ADD open circuit: %s->%s",
pkt.inKey(), pkt.outKey())
l.openedCircuits = append(l.openedCircuits, pkt.inKey())
l.keystoneBatch = append(l.keystoneBatch, pkt.keystone())
l.cfg.Peer.SendMessage(htlc)
case *lnwire.UpdateFulfillHTLC:
// An HTLC we forward to the switch has just settled somewhere
// upstream. Therefore we settle the HTLC within the our local
// state machine.
err := l.channel.SettleHTLC(
closedCircuitRef := pkt.inKey()
if err := l.channel.SettleHTLC(
htlc.PaymentPreimage,
pkt.incomingHTLCID,
nil,
nil,
nil,
)
if err != nil {
pkt.sourceRef,
pkt.destRef,
&closedCircuitRef,
); err != nil {
// TODO(roasbeef): broadcast on-chain
l.fail("unable to settle incoming HTLC: %v", err)
return
}
l.debugf("Queueing removal of SETTLE closed circuit: %s->%s",
pkt.inKey(), pkt.outKey())
l.closedCircuits = append(l.closedCircuits, pkt.inKey())
// With the HTLC settled, we'll need to populate the wire
// message to target the specific channel and HTLC to be
// cancelled.
@ -937,18 +1086,23 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
case *lnwire.UpdateFailHTLC:
// An HTLC cancellation has been triggered somewhere upstream,
// we'll remove then HTLC from our local state machine.
err := l.channel.FailHTLC(
closedCircuitRef := pkt.inKey()
if err := l.channel.FailHTLC(
pkt.incomingHTLCID,
htlc.Reason,
nil,
nil,
nil,
)
if err != nil {
pkt.sourceRef,
pkt.destRef,
&closedCircuitRef,
); err != nil {
log.Errorf("unable to cancel HTLC: %v", err)
return
}
l.debugf("Queueing removal of FAIL closed circuit: %s->%s",
pkt.inKey(), pkt.outKey())
l.closedCircuits = append(l.closedCircuits, pkt.inKey())
// With the HTLC removed, we'll need to populate the wire
// message to target the specific channel and HTLC to be
// cancelled. The "Reason" field will have already been set
@ -1141,32 +1295,21 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// We've received a revocation from the remote chain, if valid,
// this moves the remote chain forward, and expands our
// revocation window.
_, adds, settleFails, err := l.channel.ReceiveRevocation(msg)
fwdPkg, adds, settleFails, err := l.channel.ReceiveRevocation(msg)
if err != nil {
l.fail("unable to accept revocation: %v", err)
return
}
// After we treat HTLCs as included in both remote/local
// commitment transactions they might be safely propagated over
// htlc switch or settled if our node was last node in htlc
// path.
htlcs := append(settleFails, adds...)
htlcsToForward := l.processLockedInHtlcs(htlcs)
go func() {
log.Debugf("ChannelPoint(%v) forwarding %v HTLC's",
l.channel.ChannelPoint(), len(htlcsToForward))
for _, packet := range htlcsToForward {
if err := l.cfg.Switch.forward(packet); err != nil {
// TODO(roasbeef): cancel back htlc
// under certain conditions?
log.Errorf("channel link(%v): "+
"unhandled error while forwarding "+
"htlc packet over htlc "+
"switch: %v", l, err)
}
l.processRemoteSettleFails(fwdPkg, settleFails)
needUpdate := l.processRemoteAdds(fwdPkg, adds)
if needUpdate {
if err := l.updateCommitTx(); err != nil {
l.fail("unable to update commitment: %v", err)
return
}
}()
}
case *lnwire.UpdateFee:
// We received fee update from peer. If we are the initiator we
@ -1179,10 +1322,97 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
}
}
// ackDownStreamPackets is responsible for removing htlcs from a link's
// mailbox for packets delivered from server, and cleaning up any circuits
// closed by signing a previous commitment txn. This method ensures that the
// circuits are removed from the circuit map before removing them from the
// link's mailbox, otherwise it could be possible for some circuit to be missed
// if this link flaps.
//
// The `forgive` flag allows this method to tolerate restarts, and ignores
// errors that could be caused by a previous circuit deletion. Under normal
// operation, this is set to false so that we would fail the link if we were
// unable to remove a circuit.
func (l *channelLink) ackDownStreamPackets(forgive bool) error {
// First, remove the downstream Add packets that were included in the
// previous commitment signature. This will prevent the Adds from being
// replayed if this link disconnects.
for _, inKey := range l.openedCircuits {
// In order to test the sphinx replay logic of the remote party,
// unsafe replay does not acknowledge the packets from the
// mailbox. We can then force a replay of any Add packets held
// in memory by disconnecting and reconnecting the link.
if l.cfg.UnsafeReplay {
continue
}
l.debugf("Removing Add packet %s from mailbox", inKey)
l.mailBox.AckPacket(inKey)
}
// Now, we will delete all circuits closed by the previous commitment
// signature, which is the result of downstream Settle/Fail packets. We
// batch them here to ensure circuits are closed atomically and for
// performance.
err := l.cfg.Circuits.DeleteCircuits(l.closedCircuits...)
switch err {
case nil:
// Successful deletion.
case ErrUnknownCircuit:
if forgive {
// After a restart, we may have already removed this
// circuit. Since it shouldn't be possible for a circuit
// to be closed by different htlcs, we assume this error
// signals that the whole batch was successfully
// removed.
l.warnf("Forgiving unknown circuit error after " +
"attempting deletion, circuit was probably " +
"removed before shutting down.")
break
}
return err
default:
l.errorf("unable to delete %d circuits: %v",
len(l.closedCircuits), err)
return err
}
// With the circuits removed from memory and disk, we now ack any
// Settle/Fails in the mailbox to ensure they do not get redelivered
// after startup. If forgive is enabled and we've reached this point,
// the circuits must have been removed at some point, so it is now safe
// to unqueue the corresponding Settle/Fails.
for _, inKey := range l.closedCircuits {
l.debugf("Removing Fail/Settle packet %s from mailbox", inKey)
l.mailBox.AckPacket(inKey)
}
// Lastly, reset our buffers to be empty while keeping any acquired
// growth in the backing array.
l.openedCircuits = l.openedCircuits[:0]
l.closedCircuits = l.closedCircuits[:0]
return nil
}
// updateCommitTx signs, then sends an update to the remote peer adding a new
// commitment to their commitment chain which includes all the latest updates
// we've received+processed up to this point.
func (l *channelLink) updateCommitTx() error {
// Preemptively write all pending keystones to disk, just in case the
// HTLCs we have in memory are included in the subsequent attempt to
// sign a commitment state.
err := l.cfg.Circuits.OpenCircuits(l.keystoneBatch...)
if err != nil {
return err
}
// Reset the batch, but keep the backing buffer to avoid reallocating.
l.keystoneBatch = l.keystoneBatch[:0]
theirCommitSig, htlcSigs, err := l.channel.SignNextCommitment()
if err == lnwallet.ErrNoWindow {
log.Tracef("revocation window exhausted, unable to send %v",
@ -1192,6 +1422,10 @@ func (l *channelLink) updateCommitTx() error {
return err
}
if err := l.ackDownStreamPackets(false); err != nil {
return err
}
commitSig := &lnwire.CommitSig{
ChanID: l.ChanID(),
CommitSig: theirCommitSig,
@ -1231,8 +1465,6 @@ func (l *channelLink) Peer() Peer {
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) ShortChanID() lnwire.ShortChannelID {
l.RLock()
defer l.RUnlock()
return l.shortChanID
}
@ -1301,6 +1533,17 @@ func (l *channelLink) Bandwidth() lnwire.MilliSatoshi {
return linkBandwidth - reserve
}
// AttachMailBox updates the current mailbox used by this link, and hooks up the
// mailbox's message and packet outboxes to the link's upstream and downstream
// chans, respectively.
func (l *channelLink) AttachMailBox(mailbox MailBox) {
l.Lock()
l.mailBox = mailbox
l.upstream = mailbox.MessageOutBox()
l.downstream = mailbox.PacketOutBox()
l.Unlock()
}
// policyUpdate is a message sent to a channel link when an outside sub-system
// wishes to update the current forwarding policy.
type policyUpdate struct {
@ -1357,8 +1600,11 @@ func (l *channelLink) String() string {
// another peer or if the update was created by user
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) HandleSwitchPacket(packet *htlcPacket) {
l.mailBox.AddPacket(packet)
func (l *channelLink) HandleSwitchPacket(pkt *htlcPacket) error {
l.tracef("received switch packet inkey=%v, outkey=%v",
pkt.inKey(), pkt.outKey())
l.mailBox.AddPacket(pkt)
return nil
}
// HandleChannelUpdate handles the htlc requests as settle/add/fail which sent
@ -1379,8 +1625,8 @@ func (l *channelLink) updateChannelFee(feePerKw lnwallet.SatPerKWeight) error {
// We skip sending the UpdateFee message if the channel is not
// currently eligible to forward messages.
if !l.EligibleToForward() {
log.Debugf("ChannelPoint(%v): skipping fee update for " +
"inactive channel")
log.Debugf("ChannelPoint(%v): skipping fee update for "+
"inactive channel", l.ChanID())
return nil
}
@ -1398,22 +1644,29 @@ func (l *channelLink) updateChannelFee(feePerKw lnwallet.SatPerKWeight) error {
return l.updateCommitTx()
}
// processLockedInHtlcs serially processes each of the log updates which have
// been "locked-in". An HTLC is considered locked-in once it has been fully
// committed to in both the remote and local commitment state. Once a channel
// updates is locked-in, then it can be acted upon, meaning: settling HTLCs,
// cancelling them, or forwarding new HTLCs to the next hop.
func (l *channelLink) processLockedInHtlcs(
paymentDescriptors []*lnwallet.PaymentDescriptor) []*htlcPacket {
// processRemoteSettleFails accepts a batch of settle/fail payment descriptors
// after receiving a revocation from the remote party, and reprocesses them in
// the context of the provided forwarding package. Any settles or fails that
// have already been acknowledged in the forwarding package will not be sent to
// the switch.
func (l *channelLink) processRemoteSettleFails(fwdPkg *channeldb.FwdPkg,
settleFails []*lnwallet.PaymentDescriptor) {
var (
needUpdate bool
packetsToForward []*htlcPacket
)
if len(settleFails) == 0 {
return
}
log.Debugf("ChannelLink(%v): settle-fail-filter %v",
l.ShortChanID(), fwdPkg.SettleFailFilter)
var switchPackets []*htlcPacket
for i, pd := range settleFails {
// Skip any settles or fails that have already been acknowledged
// by the incoming link that originated the forwarded Add.
if fwdPkg.SettleFailFilter.Contains(uint16(i)) {
continue
}
for _, pd := range paymentDescriptors {
// TODO(roasbeef): rework log entries to a shared
// interface.
switch pd.EntryType {
// A settle for an HTLC we previously forwarded HTLC has been
@ -1423,7 +1676,7 @@ func (l *channelLink) processLockedInHtlcs(
settlePacket := &htlcPacket{
outgoingChanID: l.ShortChanID(),
outgoingHTLCID: pd.ParentIndex,
amount: pd.Amount,
destRef: pd.DestRef,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: pd.RPreimage,
},
@ -1432,7 +1685,7 @@ func (l *channelLink) processLockedInHtlcs(
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
packetsToForward = append(packetsToForward, settlePacket)
switchPackets = append(switchPackets, settlePacket)
l.overflowQueue.SignalFreeSlot()
// A failureCode message for a previously forwarded HTLC has
@ -1445,7 +1698,7 @@ func (l *channelLink) processLockedInHtlcs(
failPacket := &htlcPacket{
outgoingChanID: l.ShortChanID(),
outgoingHTLCID: pd.ParentIndex,
amount: pd.Amount,
destRef: pd.DestRef,
htlc: &lnwire.UpdateFailHTLC{
Reason: lnwire.OpaqueReason(pd.FailReason),
},
@ -1454,8 +1707,83 @@ func (l *channelLink) processLockedInHtlcs(
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
packetsToForward = append(packetsToForward, failPacket)
switchPackets = append(switchPackets, failPacket)
l.overflowQueue.SignalFreeSlot()
}
}
go l.forwardBatch(switchPackets...)
}
// processRemoteAdds serially processes each of the Add payment descriptors
// which have been "locked-in" by receiving a revocation from the remote party.
// The forwarding package provided instructs how to process this batch,
// indicating whether this is the first time these Adds are being processed, or
// whether we are reprocessing as a result of a failure or restart. Adds that
// have already been acknowledged in the forwarding package will be ignored.
func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
lockedInHtlcs []*lnwallet.PaymentDescriptor) bool {
l.tracef("processing %d remote adds for height %d",
len(lockedInHtlcs), fwdPkg.Height)
decodeReqs := make([]DecodeHopIteratorRequest, 0, len(lockedInHtlcs))
for _, pd := range lockedInHtlcs {
switch pd.EntryType {
// TODO(conner): remove type switch?
case lnwallet.Add:
// Before adding the new htlc to the state machine,
// parse the onion object in order to obtain the
// routing information with DecodeHopIterator function
// which process the Sphinx packet.
onionReader := bytes.NewReader(pd.OnionBlob)
req := DecodeHopIteratorRequest{
OnionReader: onionReader,
RHash: pd.RHash[:],
IncomingCltv: pd.Timeout,
}
decodeReqs = append(decodeReqs, req)
}
}
// Atomically decode the incoming htlcs, simultaneously checking for
// replay attempts. A particular index in the returned, spare list of
// channel iterators should only be used if the failure code at the same
// index is lnwire.FailCodeNone.
decodeResps, sphinxErr := l.cfg.DecodeHopIterators(
fwdPkg.ID(), decodeReqs,
)
if sphinxErr != nil {
l.errorf("unable to decode hop iterators: %v", sphinxErr)
l.fail(ErrInternalLinkFailure.Error())
return false
}
var (
needUpdate bool
switchPackets []*htlcPacket
)
for i, pd := range lockedInHtlcs {
idx := uint16(i)
if fwdPkg.State == channeldb.FwdStateProcessed &&
fwdPkg.AckFilter.Contains(idx) {
// If this index is already found in the ack filter, the
// response to this forwarding decision has already been
// committed by one of our commitment txns. ADDs in this
// state are waiting for the rest of the fwding package
// to get acked before being garbage collected.
continue
}
// TODO(roasbeef): rework log entries to a shared
// interface.
switch pd.EntryType {
// An incoming HTLC add has been full-locked in. As a result we
// can now examine the forwarding details of the HTLC, and the
@ -1472,23 +1800,13 @@ func (l *channelLink) processLockedInHtlcs(
// parse the onion object in order to obtain the
// routing information with DecodeHopIterator function
// which process the Sphinx packet.
//
// We include the payment hash of the htlc as it's
// authenticated within the Sphinx packet itself as
// associated data in order to thwart attempts a replay
// attacks. In the case of a replay, an attacker is
// *forced* to use the same payment hash twice, thereby
// losing their money entirely.
onionReader := bytes.NewReader(onionBlob[:])
chanIterator, failureCode := l.cfg.DecodeHopIterator(
onionReader, pd.RHash[:], pd.Timeout,
)
chanIterator, failureCode := decodeResps[i].Result()
if failureCode != lnwire.CodeNone {
// If we're unable to process the onion blob
// than we should send the malformed htlc error
// to payment sender.
l.sendMalformedHTLCError(pd.HtlcIndex, failureCode,
onionBlob[:])
onionBlob[:], pd.SourceRef)
needUpdate = true
log.Errorf("unable to decode onion hop "+
@ -1507,7 +1825,7 @@ func (l *channelLink) processLockedInHtlcs(
// than we should send the malformed htlc error
// to payment sender.
l.sendMalformedHTLCError(pd.HtlcIndex, failureCode,
onionBlob[:])
onionBlob[:], pd.SourceRef)
needUpdate = true
log.Errorf("unable to decode onion "+
@ -1520,6 +1838,7 @@ func (l *channelLink) processLockedInHtlcs(
fwdInfo := chanIterator.ForwardingInstructions()
switch fwdInfo.NextHop {
case exitHop:
if l.cfg.DebugHTLC && l.cfg.HodlHTLC {
log.Warnf("hodl HTLC mode enabled, " +
"will not attempt to settle " +
@ -1538,7 +1857,8 @@ func (l *channelLink) processLockedInHtlcs(
pd.Timeout, heightNow)
failure := lnwire.FailFinalIncorrectCltvExpiry{}
l.sendHTLCError(pd.HtlcIndex, &failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, &failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1553,23 +1873,34 @@ func (l *channelLink) processLockedInHtlcs(
log.Errorf("unable to query invoice registry: "+
" %v", err)
failure := lnwire.FailUnknownPaymentHash{}
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
// If this invoice has already been settled,
// then we'll reject it as we don't allow an
// invoice to be paid twice.
if invoice.Terms.Settled == true {
log.Warnf("Rejecting duplicate "+
// If the invoice is already settled, we choose
// to accept the payment to simplify failure
// recovery.
//
// NOTE: Though our recovery and forwarding logic is
// predominately batched, settling invoices
// happens iteratively. We may reject one of of
// two payments for the same rhash at first, but
// then restart and reject both after seeing
// that the invoice has been settled. Without
// any record of which one settles first, it is
// ambiguous as to which one actually settled
// the invoice. Thus, by accepting all payments,
// we eliminate the race condition that can lead
// to this inconsistency.
//
// TODO(conner): track ownership of settlements
// to properly recover from failures? or add
// batch invoice settlement
if invoice.Terms.Settled {
log.Warnf("Accepting duplicate "+
"payment for hash=%x", pd.RHash[:])
failure := lnwire.FailUnknownPaymentHash{}
l.sendHTLCError(
pd.HtlcIndex, failure, obfuscator,
)
needUpdate = true
continue
}
// If we're not currently in debug mode, and
@ -1591,7 +1922,8 @@ func (l *channelLink) processLockedInHtlcs(
"amount: expected %v, received %v",
invoice.Terms.Value, pd.Amount)
failure := lnwire.FailIncorrectPaymentAmount{}
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1618,7 +1950,8 @@ func (l *channelLink) processLockedInHtlcs(
fwdInfo.AmountToForward)
failure := lnwire.FailIncorrectPaymentAmount{}
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1640,7 +1973,9 @@ func (l *channelLink) processLockedInHtlcs(
failure := lnwire.NewFinalIncorrectCltvExpiry(
fwdInfo.OutgoingCTLV,
)
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex,
failure, obfuscator,
pd.SourceRef)
needUpdate = true
continue
case pd.Timeout != fwdInfo.OutgoingCTLV:
@ -1652,28 +1987,33 @@ func (l *channelLink) processLockedInHtlcs(
failure := lnwire.NewFinalIncorrectCltvExpiry(
fwdInfo.OutgoingCTLV,
)
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex,
failure, obfuscator,
pd.SourceRef)
needUpdate = true
continue
}
}
preimage := invoice.Terms.PaymentPreimage
err = l.channel.SettleHTLC(preimage, pd.HtlcIndex, nil, nil, nil)
err = l.channel.SettleHTLC(preimage,
pd.HtlcIndex, pd.SourceRef, nil, nil)
if err != nil {
l.fail("unable to settle htlc: %v", err)
return nil
return false
}
// Notify the invoiceRegistry of the invoices
// we just settled with this latest commitment
// Notify the invoiceRegistry of the invoices we
// just settled with this latest commitment
// update.
err = l.cfg.Registry.SettleInvoice(invoiceHash)
if err != nil {
l.fail("unable to settle invoice: %v", err)
return nil
return false
}
l.infof("Settling %x as exit hop", pd.RHash)
// HTLC was successfully settled locally send
// notification about it remote peer.
l.cfg.Peer.SendMessage(&lnwire.UpdateFulfillHTLC{
@ -1688,6 +2028,53 @@ func (l *channelLink) processLockedInHtlcs(
// constraints have been properly met by by this
// incoming HTLC.
default:
switch fwdPkg.State {
case channeldb.FwdStateProcessed:
if !fwdPkg.FwdFilter.Contains(idx) {
// This add was not forwarded on
// the previous processing
// phase, run it through our
// validation pipeline to
// reproduce an error. This may
// trigger a different error due
// to expiring timelocks, but we
// expect that an error will be
// reproduced.
break
}
addMsg := &lnwire.UpdateAddHTLC{
Expiry: fwdInfo.OutgoingCTLV,
Amount: fwdInfo.AmountToForward,
PaymentHash: pd.RHash,
}
// Finally, we'll encode the onion packet for
// the _next_ hop using the hop iterator
// decoded for the current hop.
buf := bytes.NewBuffer(addMsg.OnionBlob[0:0])
// We know this cannot fail, as this ADD
// was marked forwarded in a previous
// round of processing.
chanIterator.EncodeNextHop(buf)
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
}
switchPackets = append(switchPackets,
updatePacket)
continue
}
// We want to avoid forwarding an HTLC which
// will expire in the near future, so we'll
// reject an HTLC if its expiration time is too
@ -1707,7 +2094,8 @@ func (l *channelLink) processLockedInHtlcs(
failure = lnwire.NewExpiryTooSoon(*update)
}
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1734,7 +2122,8 @@ func (l *channelLink) processLockedInHtlcs(
pd.Amount, *update)
}
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1779,7 +2168,8 @@ func (l *channelLink) processLockedInHtlcs(
*update)
}
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1806,12 +2196,13 @@ func (l *channelLink) processLockedInHtlcs(
if err != nil {
l.fail("unable to create channel update "+
"while handling the error: %v", err)
return nil
return false
}
failure := lnwire.NewIncorrectCltvExpiry(
pd.Timeout, *update)
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
@ -1838,42 +2229,107 @@ func (l *channelLink) processLockedInHtlcs(
"remaining route %v", err)
failure := lnwire.NewTemporaryChannelFailure(nil)
l.sendHTLCError(pd.HtlcIndex, failure, obfuscator)
l.sendHTLCError(pd.HtlcIndex, failure,
obfuscator, pd.SourceRef)
needUpdate = true
continue
}
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
incomingHtlcAmt: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
// Now that this add has been reprocessed, only
// append it to our list of packets to forward
// to the switch this is the first time
// processing the add. If the fwd pkg has
// already been processed, then we entered the
// above section to recreate a previous error.
// If the packet had previously been forwarded,
// it would have been added to switchPackets at
// the top of this section.
if fwdPkg.State == channeldb.FwdStateLockedIn {
updatePacket := &htlcPacket{
incomingChanID: l.ShortChanID(),
incomingHTLCID: pd.HtlcIndex,
outgoingChanID: fwdInfo.NextHop,
sourceRef: pd.SourceRef,
incomingAmount: pd.Amount,
amount: addMsg.Amount,
htlc: addMsg,
obfuscator: obfuscator,
}
fwdPkg.FwdFilter.Set(idx)
switchPackets = append(switchPackets,
updatePacket)
}
packetsToForward = append(packetsToForward, updatePacket)
}
}
}
if needUpdate {
// With all the settle/cancel updates added to the local and
// remote HTLC logs, initiate a state transition by updating
// the remote commitment chain.
if err := l.updateCommitTx(); err != nil {
l.fail("unable to update commitment: %v", err)
return nil
// Commit the htlcs we are intending to forward if this package has not
// been fully processed.
if fwdPkg.State == channeldb.FwdStateLockedIn {
err := l.channel.SetFwdFilter(fwdPkg.Height, fwdPkg.FwdFilter)
if err != nil {
l.fail("unable to set fwd filter: %v", err)
return false
}
}
return packetsToForward
if len(switchPackets) == 0 {
return needUpdate
}
l.debugf("forwarding %d packets to switch", len(switchPackets))
go l.forwardBatch(switchPackets...)
return needUpdate
}
// forwardBatch forwards the given htlcPackets to the switch, and waits on the
// err chan for the individual responses. This method is intended to be spawned
// as a goroutine so the responses can be handled in the background.
func (l *channelLink) forwardBatch(packets ...*htlcPacket) {
// Don't forward packets for which we already have a response in our
// mailbox. This could happen if a packet fails and is buffered in the
// mailbox, and the incoming link flaps.
var filteredPkts = make([]*htlcPacket, 0, len(packets))
for _, pkt := range packets {
if l.mailBox.HasPacket(pkt.inKey()) {
continue
}
filteredPkts = append(filteredPkts, pkt)
}
errChan := l.cfg.ForwardPackets(filteredPkts...)
l.handleBatchFwdErrs(errChan)
}
// handleBatchFwdErrs waits on the given errChan until it is closed, logging the
// errors returned from any unsuccessful forwarding attempts.
func (l *channelLink) handleBatchFwdErrs(errChan chan error) {
for {
err, ok := <-errChan
if !ok {
// Err chan has been drained or switch is shutting down.
// Either way, return.
return
}
if err == nil {
continue
}
l.errorf("unhandled error while forwarding htlc packet over "+
"htlcswitch: %v", err)
}
}
// sendHTLCError functions cancels HTLC and send cancel message back to the
// peer from which HTLC was received.
func (l *channelLink) sendHTLCError(htlcIndex uint64,
failure lnwire.FailureMessage, e ErrorEncrypter) {
failure lnwire.FailureMessage, e ErrorEncrypter,
sourceRef *channeldb.AddRef) {
reason, err := e.EncryptFirstHop(failure)
if err != nil {
@ -1881,7 +2337,7 @@ func (l *channelLink) sendHTLCError(htlcIndex uint64,
return
}
err = l.channel.FailHTLC(htlcIndex, reason, nil, nil, nil)
err = l.channel.FailHTLC(htlcIndex, reason, sourceRef, nil, nil)
if err != nil {
log.Errorf("unable cancel htlc: %v", err)
return
@ -1897,10 +2353,10 @@ func (l *channelLink) sendHTLCError(htlcIndex uint64,
// sendMalformedHTLCError helper function which sends the malformed HTLC update
// to the payment sender.
func (l *channelLink) sendMalformedHTLCError(htlcIndex uint64,
code lnwire.FailCode, onionBlob []byte) {
code lnwire.FailCode, onionBlob []byte, sourceRef *channeldb.AddRef) {
shaOnionBlob := sha256.Sum256(onionBlob)
err := l.channel.MalformedFailHTLC(htlcIndex, code, shaOnionBlob, nil)
err := l.channel.MalformedFailHTLC(htlcIndex, code, shaOnionBlob, sourceRef)
if err != nil {
log.Errorf("unable cancel htlc: %v", err)
return
@ -1921,3 +2377,33 @@ func (l *channelLink) fail(format string, a ...interface{}) {
log.Error(reason)
go l.cfg.Peer.Disconnect(reason)
}
// infof prefixes the channel's identifier before printing to info log.
func (l *channelLink) infof(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
log.Infof("ChannelLink(%s) %s", l.ShortChanID(), msg)
}
// debugf prefixes the channel's identifier before printing to debug log.
func (l *channelLink) debugf(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
log.Debugf("ChannelLink(%s) %s", l.ShortChanID(), msg)
}
// warnf prefixes the channel's identifier before printing to warn log.
func (l *channelLink) warnf(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
log.Warnf("ChannelLink(%s) %s", l.ShortChanID(), msg)
}
// errorf prefixes the channel's identifier before printing to error log.
func (l *channelLink) errorf(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
log.Errorf("ChannelLink(%s) %s", l.ShortChanID(), msg)
}
// tracef prefixes the channel's identifier before printing to trace log.
func (l *channelLink) tracef(format string, a ...interface{}) {
msg := fmt.Sprintf(format, a...)
log.Tracef("ChannelLink(%s) %s", l.ShortChanID(), msg)
}

@ -2,7 +2,10 @@ package htlcswitch
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"io"
"runtime"
"strings"
"sync"
@ -837,7 +840,7 @@ func TestUpdateForwardingPolicy(t *testing.T) {
ferr, ok := err.(*ForwardingError)
if !ok {
t.Fatalf("expected a ForwardingError, instead got: %T", err)
t.Fatalf("expected a ForwardingError, instead got (%T): %v", err, err)
}
switch ferr.FailureMessage.(type) {
case *lnwire.FailFeeInsufficient:
@ -1050,7 +1053,11 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) {
htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight,
n.firstBobChannelLink, n.carolChannelLink)
davePub := newMockServer(t, "dave").PubKey()
daveServer, err := newMockServer(t, "dave", nil)
if err != nil {
t.Fatalf("unable to init dave's server: %v", err)
}
davePub := daveServer.PubKey()
receiver := n.bobServer
rhash, err := n.makePayment(n.aliceServer, n.bobServer, davePub, hops,
amount, htlcAmt, totalTimelock).Wait(30 * time.Second)
@ -1412,7 +1419,14 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
},
}
chanID := lnwire.NewShortChanIDFromInt(4)
var chanIDBytes [8]byte
if _, err := io.ReadFull(rand.Reader, chanIDBytes[:]); err != nil {
return nil, nil, nil, nil, err
}
chanID := lnwire.NewShortChanIDFromInt(
binary.BigEndian.Uint64(chanIDBytes[:]))
aliceChannel, bobChannel, fCleanUp, _, err := createTestChannel(
alicePrivKey, bobPrivKey, chanAmt, chanAmt,
chanReserve, chanReserve, chanID,
@ -1423,8 +1437,8 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
var (
invoiceRegistry = newMockRegistry()
decoder = &mockIteratorDecoder{}
obfuscator = newMockObfuscator()
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
sentMsgs: make(chan lnwire.Message, 2000),
quit: make(chan struct{}),
@ -1442,14 +1456,25 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
preimageMap: make(map[[32]byte][]byte),
}
aliceDb := aliceChannel.State().Db
aliceSwitch, err := New(Config{DB: aliceDb})
if err != nil {
return nil, nil, nil, nil, err
}
t := make(chan time.Time)
ticker := &mockTicker{t}
aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
Switch: New(Config{}),
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (ErrorEncrypter, lnwire.FailCode) {
FwrdingPolicy: globalPolicy,
Peer: alicePeer,
Switch: aliceSwitch,
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: aliceSwitch.ForwardPackets,
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeHopIterators: decoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
GetLastChannelUpdate: mockGetChanUpdateMessage,
@ -1457,10 +1482,11 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil
},
Registry: invoiceRegistry,
ChainEvents: &contractcourt.ChainEventSubscription{},
BlockEpochs: globalEpoch,
BatchTicker: ticker,
Registry: invoiceRegistry,
ChainEvents: &contractcourt.ChainEventSubscription{},
BlockEpochs: globalEpoch,
BatchTicker: ticker,
FwdPkgGCTicker: NewBatchTicker(time.NewTicker(5 * time.Second)),
// Make the BatchSize large enough to not
// trigger commit update automatically during tests.
BatchSize: 10000,
@ -1468,6 +1494,9 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
const startingHeight = 100
aliceLink := NewChannelLink(aliceCfg, aliceChannel, startingHeight)
mailbox := newMemoryMailBox()
mailbox.Start()
aliceLink.AttachMailBox(mailbox)
if err := aliceLink.Start(); err != nil {
return nil, nil, nil, nil, err
}
@ -1659,25 +1688,27 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
// We'll start the test by creating a single instance of
const chanAmt = btcutil.SatoshiPerBitcoin * 5
link, bobChannel, tmr, cleanUp, err := newSingleLinkTestHarness(chanAmt, 0)
aliceLink, bobChannel, tmr, cleanUp, err := newSingleLinkTestHarness(chanAmt, 0)
if err != nil {
t.Fatalf("unable to create link: %v", err)
}
defer cleanUp()
var (
carolChanID = lnwire.NewShortChanIDFromInt(3)
mockBlob [lnwire.OnionPacketSize]byte
aliceLink = link.(*channelLink)
aliceChannel = aliceLink.channel
defaultCommitFee = aliceChannel.StateSnapshot().CommitFee
coreChan = aliceLink.(*channelLink).channel
coreLink = aliceLink.(*channelLink)
defaultCommitFee = coreChan.StateSnapshot().CommitFee
aliceStartingBandwidth = aliceLink.Bandwidth()
aliceMsgs = aliceLink.cfg.Peer.(*mockPeer).sentMsgs
aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs
)
// We put Alice into HodlHTLC mode, such that she won't settle
// incoming HTLCs automatically.
aliceLink.cfg.HodlHTLC = true
aliceLink.cfg.DebugHTLC = true
coreLink.cfg.HodlHTLC = true
coreLink.cfg.DebugHTLC = true
estimator := &lnwallet.StaticFeeEstimator{
FeeRate: 24,
@ -1705,9 +1736,22 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
t.Fatalf("unable to create payment: %v", err)
}
addPkt := htlcPacket{
htlc: htlc,
htlc: htlc,
incomingChanID: sourceHop,
incomingHTLCID: 0,
obfuscator: NewMockObfuscator(),
}
circuit := makePaymentCircuit(&htlc.PaymentHash, &addPkt)
_, err = coreLink.cfg.Switch.commitCircuits(&circuit)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
addPkt.circuit = &circuit
if err := aliceLink.HandleSwitchPacket(&addPkt); err != nil {
t.Fatalf("unable to handle switch packet: %v", err)
}
aliceLink.HandleSwitchPacket(&addPkt)
time.Sleep(time.Millisecond * 500)
// The resulting bandwidth should reflect that Alice is paying the
@ -1733,10 +1777,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
}
// Lock in the HTLC.
if err := updateState(tmr, aliceLink, bobChannel, true); err != nil {
if err := updateState(tmr, coreLink, bobChannel, true); err != nil {
t.Fatalf("unable to update state: %v", err)
}
// Locking in the HTLC should not change Alice's bandwidth.
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcAmt-htlcFee)
@ -1748,7 +1791,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
t.Fatalf("unable to settle htlc: %v", err)
}
htlcSettle := &lnwire.UpdateFulfillHTLC{
ID: bobIndex,
ID: 0,
PaymentPreimage: invoice.Terms.PaymentPreimage,
}
aliceLink.HandleChannelUpdate(htlcSettle)
@ -1759,7 +1802,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcAmt-htlcFee)
// Lock in the settle.
if err := updateState(tmr, aliceLink, bobChannel, false); err != nil {
if err := updateState(tmr, coreLink, bobChannel, false); err != nil {
t.Fatalf("unable to update state: %v", err)
}
@ -1773,9 +1816,22 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
t.Fatalf("unable to create payment: %v", err)
}
addPkt = htlcPacket{
htlc: htlc,
htlc: htlc,
incomingChanID: sourceHop,
incomingHTLCID: 1,
obfuscator: NewMockObfuscator(),
}
circuit = makePaymentCircuit(&htlc.PaymentHash, &addPkt)
_, err = coreLink.cfg.Switch.commitCircuits(&circuit)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
addPkt.circuit = &circuit
if err := aliceLink.HandleSwitchPacket(&addPkt); err != nil {
t.Fatalf("unable to handle switch packet: %v", err)
}
aliceLink.HandleSwitchPacket(&addPkt)
time.Sleep(time.Millisecond * 500)
// Again, Alice's bandwidth decreases by htlcAmt+htlcFee.
@ -1787,6 +1843,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatalf("did not receive message")
}
addHtlc, ok = msg.(*lnwire.UpdateAddHTLC)
if !ok {
t.Fatalf("expected UpdateAddHTLC, got %T", msg)
@ -1798,7 +1855,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
}
// Lock in the HTLC, which should not affect the bandwidth.
if err := updateState(tmr, aliceLink, bobChannel, true); err != nil {
if err := updateState(tmr, coreLink, bobChannel, true); err != nil {
t.Fatalf("unable to update state: %v", err)
}
@ -1812,9 +1869,10 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
t.Fatalf("unable to fail htlc: %v", err)
}
failMsg := &lnwire.UpdateFailHTLC{
ID: bobIndex,
ID: 1,
Reason: lnwire.OpaqueReason([]byte("nop")),
}
aliceLink.HandleChannelUpdate(failMsg)
time.Sleep(time.Millisecond * 500)
@ -1822,7 +1880,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcAmt*2-htlcFee)
// Lock in the Fail.
if err := updateState(tmr, aliceLink, bobChannel, false); err != nil {
if err := updateState(tmr, coreLink, bobChannel, false); err != nil {
t.Fatalf("unable to update state: %v", err)
}
@ -1834,7 +1892,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
// remain unchanged (but Alice will need to pay the fee for the extra
// HTLC).
htlcAmt, totalTimelock, hops := generateHops(htlcAmt, testStartingHeight,
aliceLink)
coreLink)
blob, err := generateRoute(hops...)
if err != nil {
t.Fatalf("unable to gen route: %v", err)
@ -1847,11 +1905,12 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
// We must add the invoice to the registry, such that Alice expects
// this payment.
err = aliceLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice)
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice)
if err != nil {
t.Fatalf("unable to add invoice to registry: %v", err)
}
htlc.ID = 0
bobIndex, err = bobChannel.AddHTLC(htlc, nil)
if err != nil {
t.Fatalf("unable to add htlc: %v", err)
@ -1862,58 +1921,84 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcAmt)
// Lock in the HTLC.
if err := updateState(tmr, aliceLink, bobChannel, false); err != nil {
if err := updateState(tmr, coreLink, bobChannel, false); err != nil {
t.Fatalf("unable to update state: %v", err)
}
// Since Bob is adding this HTLC, Alice only needs to pay the fee.
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcAmt-htlcFee)
time.Sleep(time.Millisecond * 500)
addPkt = htlcPacket{
htlc: htlc,
incomingChanID: aliceLink.ShortChanID(),
incomingHTLCID: 0,
obfuscator: NewMockObfuscator(),
}
circuit = makePaymentCircuit(&htlc.PaymentHash, &addPkt)
_, err = coreLink.cfg.Switch.commitCircuits(&circuit)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
addPkt.outgoingChanID = carolChanID
addPkt.outgoingHTLCID = 0
err = coreLink.cfg.Switch.openCircuits(addPkt.keystone())
if err != nil {
t.Fatalf("unable to set keystone: %v", err)
}
// Next, we'll settle the HTLC with our knowledge of the pre-image that
// we eventually learn (simulating a multi-hop payment). The bandwidth
// of the channel should now be re-balanced to the starting point.
settlePkt := htlcPacket{
incomingChanID: aliceLink.ShortChanID(),
incomingHTLCID: 0,
circuit: &circuit,
outgoingChanID: addPkt.outgoingChanID,
outgoingHTLCID: addPkt.outgoingHTLCID,
htlc: &lnwire.UpdateFulfillHTLC{
ID: bobIndex,
ID: 0,
PaymentPreimage: invoice.Terms.PaymentPreimage,
},
obfuscator: NewMockObfuscator(),
}
aliceLink.HandleSwitchPacket(&settlePkt)
if err := aliceLink.HandleSwitchPacket(&settlePkt); err != nil {
t.Fatalf("unable to handle switch packet: %v", err)
}
time.Sleep(time.Millisecond * 500)
// Settling this HTLC gives Alice all her original bandwidth back.
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth)
// Alice wil send the Settle to Bob.
select {
case msg = <-aliceMsgs:
case <-time.After(2 * time.Second):
t.Fatalf("did not receive message")
}
settleHtlc, ok := msg.(*lnwire.UpdateFulfillHTLC)
settleMsg, ok := msg.(*lnwire.UpdateFulfillHTLC)
if !ok {
t.Fatalf("expected UpdateFulfillHTLC, got %T", msg)
}
pre := settleHtlc.PaymentPreimage
idx := settleHtlc.ID
err = bobChannel.ReceiveHTLCSettle(pre, idx)
err = bobChannel.ReceiveHTLCSettle(settleMsg.PaymentPreimage, settleMsg.ID)
if err != nil {
t.Fatalf("unable to receive settle: %v", err)
t.Fatalf("failed receiving fail htlc: %v", err)
}
// After a settle the link should do a state transition automatically,
// so we don't have to trigger it.
if err := handleStateUpdate(aliceLink, bobChannel); err != nil {
// After failing an HTLC, the link will automatically trigger
// a state update.
if err := handleStateUpdate(coreLink, bobChannel); err != nil {
t.Fatalf("unable to update state: %v", err)
}
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth)
// Finally, we'll test the scenario of failing an HTLC received from the
// Finally, we'll test the scenario of failing an HTLC received by the
// remote node. This should result in no perceived bandwidth changes.
htlcAmt, totalTimelock, hops = generateHops(htlcAmt, testStartingHeight,
aliceLink)
coreLink)
blob, err = generateRoute(hops...)
if err != nil {
t.Fatalf("unable to gen route: %v", err)
@ -1922,7 +2007,8 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
if err != nil {
t.Fatalf("unable to create payment: %v", err)
}
if err := aliceLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice); err != nil {
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(*invoice)
if err != nil {
t.Fatalf("unable to add invoice to registry: %v", err)
}
@ -1940,21 +2026,49 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
// No changes before the HTLC is locked in.
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth)
if err := updateState(tmr, aliceLink, bobChannel, false); err != nil {
if err := updateState(tmr, coreLink, bobChannel, false); err != nil {
t.Fatalf("unable to update state: %v", err)
}
// After lock-in, Alice will have to pay the htlc fee.
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcFee)
// Now fail this HTLC.
failPkt := htlcPacket{
incomingHTLCID: bobIndex,
htlc: &lnwire.UpdateFailHTLC{
ID: bobIndex,
},
addPkt = htlcPacket{
htlc: htlc,
incomingChanID: aliceLink.ShortChanID(),
incomingHTLCID: 1,
obfuscator: NewMockObfuscator(),
}
circuit = makePaymentCircuit(&htlc.PaymentHash, &addPkt)
_, err = coreLink.cfg.Switch.commitCircuits(&circuit)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
addPkt.outgoingChanID = carolChanID
addPkt.outgoingHTLCID = 1
err = coreLink.cfg.Switch.openCircuits(addPkt.keystone())
if err != nil {
t.Fatalf("unable to set keystone: %v", err)
}
failPkt := htlcPacket{
incomingChanID: aliceLink.ShortChanID(),
incomingHTLCID: 1,
circuit: &circuit,
outgoingChanID: addPkt.outgoingChanID,
outgoingHTLCID: addPkt.outgoingHTLCID,
htlc: &lnwire.UpdateFailHTLC{
ID: 1,
},
obfuscator: NewMockObfuscator(),
}
if err := aliceLink.HandleSwitchPacket(&failPkt); err != nil {
t.Fatalf("unable to handle switch packet: %v", err)
}
aliceLink.HandleSwitchPacket(&failPkt)
time.Sleep(time.Millisecond * 500)
// Alice should get all her bandwidth back.
@ -1977,7 +2091,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
// After failing an HTLC, the link will automatically trigger
// a state update.
if err := handleStateUpdate(aliceLink, bobChannel); err != nil {
if err := handleStateUpdate(coreLink, bobChannel); err != nil {
t.Fatalf("unable to update state: %v", err)
}
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth)
@ -2015,20 +2129,27 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) {
}
feePerKw := feeRate.FeePerKWeight()
// The starting bandwidth of the channel should be exactly the amount
// that we created the channel between her and Bob.
expectedBandwidth := lnwire.NewMSatFromSatoshis(chanAmt - defaultCommitFee)
assertLinkBandwidth(t, aliceLink, expectedBandwidth)
addLinkHTLC := func(amt lnwire.MilliSatoshi) [32]byte {
var htlcID uint64
addLinkHTLC := func(id uint64, amt lnwire.MilliSatoshi) [32]byte {
invoice, htlc, err := generatePayment(amt, amt, 5, mockBlob)
if err != nil {
t.Fatalf("unable to create payment: %v", err)
}
aliceLink.HandleSwitchPacket(&htlcPacket{
htlc: htlc,
amount: amt,
})
addPkt := &htlcPacket{
htlc: htlc,
incomingHTLCID: id,
amount: amt,
obfuscator: NewMockObfuscator(),
}
circuit := makePaymentCircuit(&htlc.PaymentHash, addPkt)
_, err = coreLink.cfg.Switch.commitCircuits(&circuit)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
addPkt.circuit = &circuit
aliceLink.HandleSwitchPacket(addPkt)
return invoice.Terms.PaymentPreimage
}
@ -2040,10 +2161,11 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) {
const numHTLCs = lnwallet.MaxHTLCNumber / 2
var preImages [][32]byte
for i := 0; i < numHTLCs; i++ {
preImage := addLinkHTLC(htlcAmt)
preImage := addLinkHTLC(htlcID, htlcAmt)
preImages = append(preImages, preImage)
totalHtlcAmt += htlcAmt
htlcID++
}
// The HTLCs should all be sent to the remote.
@ -2051,8 +2173,8 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) {
for i := 0; i < numHTLCs; i++ {
select {
case msg = <-aliceMsgs:
case <-time.After(2 * time.Second):
t.Fatalf("did not receive message")
case <-time.After(5 * time.Second):
t.Fatalf("did not receive message %d", i)
}
addHtlc, ok := msg.(*lnwire.UpdateAddHTLC)
@ -2078,7 +2200,7 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) {
htlcFee := lnwire.NewMSatFromSatoshis(
feePerKw.FeeForWeight(commitWeight),
)
expectedBandwidth = aliceStartingBandwidth - totalHtlcAmt - htlcFee
expectedBandwidth := aliceStartingBandwidth - totalHtlcAmt - htlcFee
expectedBandwidth += lnwire.NewMSatFromSatoshis(defaultCommitFee)
assertLinkBandwidth(t, aliceLink, expectedBandwidth)
@ -2094,10 +2216,11 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) {
// bandwidth accounting is done properly.
const numOverFlowHTLCs = 20
for i := 0; i < numOverFlowHTLCs; i++ {
preImage := addLinkHTLC(htlcAmt)
preImage := addLinkHTLC(htlcID, htlcAmt)
preImages = append(preImages, preImage)
totalHtlcAmt += htlcAmt
htlcID++
}
// No messages should be sent to the remote at this point.
@ -2245,10 +2368,18 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) {
if err != nil {
t.Fatalf("unable to create payment: %v", err)
}
addPkt := htlcPacket{
htlc: htlc,
addPkt := &htlcPacket{
htlc: htlc,
obfuscator: NewMockObfuscator(),
}
aliceLink.HandleSwitchPacket(&addPkt)
circuit := makePaymentCircuit(&htlc.PaymentHash, addPkt)
_, err = coreLink.cfg.Switch.commitCircuits(&circuit)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
aliceLink.HandleSwitchPacket(addPkt)
time.Sleep(time.Millisecond * 100)
assertLinkBandwidth(t, aliceLink, aliceStartingBandwidth-htlcAmt-htlcFee)
@ -2834,11 +2965,11 @@ func TestChannelLinkUpdateCommitFee(t *testing.T) {
}
}
// TestChannelLinkRejectDuplicatePayment tests that if a link receives an
// incoming HTLC for a payment we have already settled, then it rejects the
// HTLC. We do this as we want to enforce the fact that invoices are only to be
// used _once.
func TestChannelLinkRejectDuplicatePayment(t *testing.T) {
// TestChannelLinkAcceptDuplicatePayment tests that if a link receives an
// incoming HTLC for a payment we have already settled, then it accepts the
// HTLC. We do this to simplify the processing of settles after restarts or
// failures, reducing ambiguity when a batch is only partially processed.
func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
t.Parallel()
// First, we'll create our traditional three hop network. We'll only be
@ -2891,8 +3022,8 @@ func TestChannelLinkRejectDuplicatePayment(t *testing.T) {
// as it's a duplicate request.
_, err = n.aliceServer.htlcSwitch.SendHTLC(n.bobServer.PubKey(), htlc,
newMockDeobfuscator())
if err.Error() != lnwire.CodeUnknownPaymentHash.String() {
t.Fatal("error haven't been received")
if err != nil {
t.Fatalf("error shouldn't have been received got: %v", err)
}
}

@ -1,23 +1,40 @@
package htlcswitch
import (
"container/list"
"errors"
"sync"
"sync/atomic"
"time"
"github.com/lightningnetwork/lnd/lnwire"
)
// mailBox is an interface which represents a concurrent-safe, in-order
// ErrMailBoxShuttingDown is returned when the mailbox is interrupted by a
// shutdown request.
var ErrMailBoxShuttingDown = errors.New("mailbox is shutting down")
// MailBox is an interface which represents a concurrent-safe, in-order
// delivery queue for messages from the network and also from the main switch.
// This struct servers as a buffer between incoming messages, and messages to
// the handled by the link. Each of the mutating methods within this interface
// should be implemented in a non-blocking manner.
type mailBox interface {
type MailBox interface {
// AddMessage appends a new message to the end of the message queue.
AddMessage(msg lnwire.Message) error
// AddPacket appends a new message to the end of the packet queue.
AddPacket(pkt *htlcPacket) error
// HasPacket queries the packets for a circuit key, this is used to drop
// packets bound for the switch that already have a queued response.
HasPacket(CircuitKey) bool
// AckPacket removes a packet from the mailboxes in-memory replay
// buffer. This will prevent a packet from being delivered after a link
// restarts if the switch has remained online.
AckPacket(CircuitKey) error
// MessageOutBox returns a channel that any new messages ready for
// delivery will be sent on.
MessageOutBox() chan lnwire.Message
@ -26,6 +43,12 @@ type mailBox interface {
// delivery will be sent on.
PacketOutBox() chan *htlcPacket
// Clears any pending wire messages from the inbox.
ResetMessages() error
// Reset the packet head to point at the first element in the list.
ResetPackets() error
// Start starts the mailbox and any goroutines it needs to operate
// properly.
Start() error
@ -34,20 +57,28 @@ type mailBox interface {
Stop() error
}
// memoryMailBox is an implementation of the mailBox struct backed by purely
// memoryMailBox is an implementation of the MailBox struct backed by purely
// in-memory queues.
type memoryMailBox struct {
wireMessages []lnwire.Message
started uint32
stopped uint32
wireMessages *list.List
wireHead *list.Element
wireMtx sync.Mutex
wireCond *sync.Cond
messageOutbox chan lnwire.Message
msgReset chan chan struct{}
htlcPkts []*htlcPacket
htlcPkts *list.List
pktIndex map[CircuitKey]*list.Element
pktHead *list.Element
pktMtx sync.Mutex
pktCond *sync.Cond
pktOutbox chan *htlcPacket
pktReset chan chan struct{}
wg sync.WaitGroup
quit chan struct{}
@ -56,9 +87,14 @@ type memoryMailBox struct {
// newMemoryMailBox creates a new instance of the memoryMailBox.
func newMemoryMailBox() *memoryMailBox {
box := &memoryMailBox{
quit: make(chan struct{}),
wireMessages: list.New(),
htlcPkts: list.New(),
messageOutbox: make(chan lnwire.Message),
pktOutbox: make(chan *htlcPacket),
msgReset: make(chan chan struct{}, 1),
pktReset: make(chan chan struct{}, 1),
pktIndex: make(map[CircuitKey]*list.Element),
quit: make(chan struct{}),
}
box.wireCond = sync.NewCond(&box.wireMtx)
box.pktCond = sync.NewCond(&box.pktMtx)
@ -66,12 +102,12 @@ func newMemoryMailBox() *memoryMailBox {
return box
}
// A compile time assertion to ensure that memoryMailBox meets the mailBox
// A compile time assertion to ensure that memoryMailBox meets the MailBox
// interface.
var _ mailBox = (*memoryMailBox)(nil)
var _ MailBox = (*memoryMailBox)(nil)
// courierType is an enum that reflects the distinct types of messages a
// mailBox can handle. Each type will be placed in an isolated mail box and
// MailBox can handle. Each type will be placed in an isolated mail box and
// will have a dedicated goroutine for delivering the messages.
type courierType uint8
@ -85,8 +121,12 @@ const (
// Start starts the mailbox and any goroutines it needs to operate properly.
//
// NOTE: This method is part of the mailBox interface.
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) Start() error {
if !atomic.CompareAndSwapUint32(&m.started, 0, 1) {
return nil
}
m.wg.Add(2)
go m.mailCourier(wireCourier)
go m.mailCourier(pktCourier)
@ -94,10 +134,90 @@ func (m *memoryMailBox) Start() error {
return nil
}
// ResetMessages blocks until all buffered wire messages are cleared.
func (m *memoryMailBox) ResetMessages() error {
msgDone := make(chan struct{})
select {
case m.msgReset <- msgDone:
return m.signalUntilReset(wireCourier, msgDone)
case <-m.quit:
return ErrMailBoxShuttingDown
}
}
// ResetPackets blocks until the head of packets buffer is reset, causing the
// packets to be redelivered in order.
func (m *memoryMailBox) ResetPackets() error {
pktDone := make(chan struct{})
select {
case m.pktReset <- pktDone:
return m.signalUntilReset(pktCourier, pktDone)
case <-m.quit:
return ErrMailBoxShuttingDown
}
}
// signalUntilReset strobes the condition variable for the specified inbox type
// until receiving a response that the mailbox has processed a reset.
func (m *memoryMailBox) signalUntilReset(cType courierType,
done chan struct{}) error {
for {
switch cType {
case wireCourier:
m.wireCond.Signal()
case pktCourier:
m.pktCond.Signal()
}
select {
case <-time.After(time.Millisecond):
continue
case <-done:
return nil
case <-m.quit:
return ErrMailBoxShuttingDown
}
}
}
// AckPacket removes the packet identified by it's incoming circuit key from the
// queue of packets to be delivered.
//
// NOTE: It is safe to call this method multiple times for the same circuit key.
func (m *memoryMailBox) AckPacket(inKey CircuitKey) error {
m.pktCond.L.Lock()
entry, ok := m.pktIndex[inKey]
if !ok {
m.pktCond.L.Unlock()
return nil
}
m.htlcPkts.Remove(entry)
delete(m.pktIndex, inKey)
m.pktCond.L.Unlock()
return nil
}
// HasPacket queries the packets for a circuit key, this is used to drop packets
// bound for the switch that already have a queued response.
func (m *memoryMailBox) HasPacket(inKey CircuitKey) bool {
m.pktCond.L.Lock()
_, ok := m.pktIndex[inKey]
m.pktCond.L.Unlock()
return ok
}
// Stop signals the mailbox and its goroutines for a graceful shutdown.
//
// NOTE: This method is part of the mailBox interface.
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) Stop() error {
if !atomic.CompareAndSwapUint32(&m.stopped, 0, 1) {
return nil
}
close(m.quit)
m.wireCond.Signal()
@ -121,10 +241,13 @@ func (m *memoryMailBox) mailCourier(cType courierType) {
switch cType {
case wireCourier:
m.wireCond.L.Lock()
for len(m.wireMessages) == 0 {
for m.wireMessages.Front() == nil {
m.wireCond.Wait()
select {
case msgDone := <-m.msgReset:
m.wireMessages.Init()
close(msgDone)
case <-m.quit:
m.wireCond.L.Unlock()
return
@ -134,10 +257,13 @@ func (m *memoryMailBox) mailCourier(cType courierType) {
case pktCourier:
m.pktCond.L.Lock()
for len(m.htlcPkts) == 0 {
for m.pktHead == nil {
m.pktCond.Wait()
select {
case pktDone := <-m.pktReset:
m.pktHead = m.htlcPkts.Front()
close(pktDone)
case <-m.quit:
m.pktCond.L.Unlock()
return
@ -155,13 +281,11 @@ func (m *memoryMailBox) mailCourier(cType courierType) {
)
switch cType {
case wireCourier:
nextMsg = m.wireMessages[0]
m.wireMessages[0] = nil // Set to nil to prevent GC leak.
m.wireMessages = m.wireMessages[1:]
entry := m.wireMessages.Front()
nextMsg = m.wireMessages.Remove(entry).(lnwire.Message)
case pktCourier:
nextPkt = m.htlcPkts[0]
m.htlcPkts[0] = nil // Set to nil to prevent GC leak.
m.htlcPkts = m.htlcPkts[1:]
nextPkt = m.pktHead.Value.(*htlcPacket)
m.pktHead = m.pktHead.Next()
}
// Now that we're done with the condition, we can unlock it to
@ -173,13 +297,17 @@ func (m *memoryMailBox) mailCourier(cType courierType) {
m.pktCond.L.Unlock()
}
// With the next message obtained, we'll now select to attempt
// to deliver the message. If we receive a kill signal, then
// we'll bail out.
switch cType {
case wireCourier:
select {
case m.messageOutbox <- nextMsg:
case msgDone := <-m.msgReset:
m.wireCond.L.Lock()
m.wireMessages.Init()
m.wireCond.L.Unlock()
close(msgDone)
case <-m.quit:
return
}
@ -187,6 +315,11 @@ func (m *memoryMailBox) mailCourier(cType courierType) {
case pktCourier:
select {
case m.pktOutbox <- nextPkt:
case pktDone := <-m.pktReset:
m.pktCond.L.Lock()
m.pktHead = m.htlcPkts.Front()
m.pktCond.L.Unlock()
close(pktDone)
case <-m.quit:
return
}
@ -197,13 +330,13 @@ func (m *memoryMailBox) mailCourier(cType courierType) {
// AddMessage appends a new message to the end of the message queue.
//
// NOTE: This method is safe for concrete use and part of the mailBox
// NOTE: This method is safe for concrete use and part of the MailBox
// interface.
func (m *memoryMailBox) AddMessage(msg lnwire.Message) error {
// First, we'll lock the condition, and add the message to the end of
// the wire message inbox.
m.wireCond.L.Lock()
m.wireMessages = append(m.wireMessages, msg)
m.wireMessages.PushBack(msg)
m.wireCond.L.Unlock()
// With the message added, we signal to the mailCourier that there are
@ -215,13 +348,22 @@ func (m *memoryMailBox) AddMessage(msg lnwire.Message) error {
// AddPacket appends a new message to the end of the packet queue.
//
// NOTE: This method is safe for concrete use and part of the mailBox
// NOTE: This method is safe for concrete use and part of the MailBox
// interface.
func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error {
// First, we'll lock the condition, and add the packet to the end of
// the htlc packet inbox.
m.pktCond.L.Lock()
m.htlcPkts = append(m.htlcPkts, pkt)
if _, ok := m.pktIndex[pkt.inKey()]; ok {
m.pktCond.L.Unlock()
return nil
}
entry := m.htlcPkts.PushBack(pkt)
m.pktIndex[pkt.inKey()] = entry
if m.pktHead == nil {
m.pktHead = entry
}
m.pktCond.L.Unlock()
// With the packet added, we signal to the mailCourier that there are
@ -234,7 +376,7 @@ func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error {
// MessageOutBox returns a channel that any new messages ready for delivery
// will be sent on.
//
// NOTE: This method is part of the mailBox interface.
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) MessageOutBox() chan lnwire.Message {
return m.messageOutbox
}
@ -242,7 +384,7 @@ func (m *memoryMailBox) MessageOutBox() chan lnwire.Message {
// PacketOutBox returns a channel that any new packets ready for delivery will
// be sent on.
//
// NOTE: This method is part of the mailBox interface.
// NOTE: This method is part of the MailBox interface.
func (m *memoryMailBox) PacketOutBox() chan *htlcPacket {
return m.pktOutbox
}

@ -25,6 +25,7 @@ func TestMailBoxCouriers(t *testing.T) {
// We'll be adding 10 message of both types to the mailbox.
const numPackets = 10
const halfPackets = numPackets / 2
// We'll add a set of random packets to the mailbox.
sentPackets := make([]*htlcPacket, numPackets)
@ -96,4 +97,53 @@ func TestMailBoxCouriers(t *testing.T) {
t.Fatalf("recvd messages mismatched: expected %v, got %v",
spew.Sdump(sentMessages), spew.Sdump(recvdMessages))
}
// Now that we've received all of the intended msgs/pkts, ack back half
// of the packets.
for _, recvdPkt := range recvdPackets[:halfPackets] {
mailBox.AckPacket(recvdPkt.inKey())
}
// With the packets drained and partially acked, we reset the mailbox,
// simulating a link shutting down and then coming back up.
mailBox.ResetMessages()
mailBox.ResetPackets()
// Now, we'll use the same alternating strategy to read from our
// mailbox. All wire messages are dropped on startup, but any unacked
// packets will be replayed in the same order they were delivered
// initially.
recvdPackets2 := make([]*htlcPacket, 0, halfPackets)
for i := 0; i < 2*halfPackets; i++ {
timeout := time.After(time.Second * 5)
if i%2 == 0 {
select {
case <-timeout:
t.Fatalf("didn't recv pkt after timeout")
case pkt := <-mailBox.PacketOutBox():
recvdPackets2 = append(recvdPackets2, pkt)
}
} else {
select {
case <-mailBox.MessageOutBox():
t.Fatalf("should not receive wire msg after reset")
default:
}
}
}
// The number of packets we received should match the number of unacked
// packets left in the mailbox.
if halfPackets != len(recvdPackets2) {
t.Fatalf("expected %v packets instead got %v", halfPackets,
len(recvdPackets))
}
// Additionally, the set of packets should match exactly with the
// unacked packets, and we should have received the packets in the exact
// same ordering that we added.
if !reflect.DeepEqual(recvdPackets[halfPackets:], recvdPackets2) {
t.Fatalf("recvd packets mismatched: expected %v, got %v",
spew.Sdump(sentPackets), spew.Sdump(recvdPackets))
}
}

@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
"io/ioutil"
"sync"
"testing"
"time"
@ -120,25 +121,48 @@ type mockServer struct {
var _ Peer = (*mockServer)(nil)
func newMockServer(t testing.TB, name string) *mockServer {
func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
if db == nil {
tempPath, err := ioutil.TempDir("", "switchdb")
if err != nil {
return nil, err
}
db, err = channeldb.Open(tempPath)
if err != nil {
return nil, err
}
}
return New(Config{
DB: db,
SwitchPackager: channeldb.NewSwitchPackager(),
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
}
func newMockServer(t testing.TB, name string, db *channeldb.DB) (*mockServer, error) {
var id [33]byte
h := sha256.Sum256([]byte(name))
copy(id[:], h[:])
return &mockServer{
t: t,
id: id,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: newMockRegistry(),
htlcSwitch: New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
}),
interceptorFuncs: make([]messageInterceptor, 0),
htlcSwitch, err := initSwitchWithDB(db)
if err != nil {
return nil, err
}
return &mockServer{
t: t,
id: id,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: newMockRegistry(),
htlcSwitch: htlcSwitch,
interceptorFuncs: make([]messageInterceptor, 0),
}, nil
}
func (s *mockServer) Start() error {
@ -196,10 +220,6 @@ type mockHopIterator struct {
hops []ForwardingInfo
}
func (r *mockHopIterator) OnionPacket() *sphinx.OnionPacket {
return nil
}
func newMockHopIterator(hops ...ForwardingInfo) HopIterator {
return &mockHopIterator{hops: hops}
}
@ -261,7 +281,8 @@ type mockObfuscator struct {
ogPacket *sphinx.OnionPacket
}
func newMockObfuscator() ErrorEncrypter {
// NewMockObfuscator initializes a dummy mockObfuscator used for testing.
func NewMockObfuscator() ErrorEncrypter {
return &mockObfuscator{}
}
@ -512,6 +533,10 @@ type mockChannelLink struct {
peer Peer
startMailBox bool
mailBox MailBox
packets chan *htlcPacket
eligible bool
@ -519,6 +544,39 @@ type mockChannelLink struct {
htlcID uint64
}
// completeCircuit is a helper method for adding the finalized payment circuit
// to the switch's circuit map. In testing, this should be executed after
// receiving an htlc from the downstream packets channel.
func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error {
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
pkt.outgoingChanID = f.shortChanID
pkt.outgoingHTLCID = f.htlcID
htlc.ID = f.htlcID
keystone := Keystone{pkt.inKey(), pkt.outKey()}
if err := f.htlcSwitch.openCircuits(keystone); err != nil {
return err
}
f.htlcID++
case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
err := f.htlcSwitch.teardownCircuit(pkt)
if err != nil {
return err
}
}
f.mailBox.AckPacket(pkt.inKey())
return nil
}
func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error {
return f.htlcSwitch.deleteCircuits(pkt.inKey())
}
func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID, peer Peer, eligible bool,
) *mockChannelLink {
@ -527,27 +585,14 @@ func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
htlcSwitch: htlcSwitch,
chanID: chanID,
shortChanID: shortChanID,
packets: make(chan *htlcPacket, 1),
peer: peer,
eligible: eligible,
}
}
func (f *mockChannelLink) HandleSwitchPacket(packet *htlcPacket) {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
f.htlcSwitch.addCircuit(&PaymentCircuit{
PaymentHash: htlc.PaymentHash,
IncomingChanID: packet.incomingChanID,
IncomingHTLCID: packet.incomingHTLCID,
OutgoingChanID: f.shortChanID,
OutgoingHTLCID: f.htlcID,
ErrorEncrypter: packet.obfuscator,
})
f.htlcID++
}
f.packets <- packet
func (f *mockChannelLink) HandleSwitchPacket(pkt *htlcPacket) error {
f.mailBox.AddPacket(pkt)
return nil
}
func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
@ -560,12 +605,22 @@ func (f *mockChannelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSato
return 0, 0, 0
}
func (f *mockChannelLink) AttachMailBox(mailBox MailBox) {
f.mailBox = mailBox
f.packets = mailBox.PacketOutBox()
}
func (f *mockChannelLink) Start() error {
f.mailBox.ResetMessages()
f.mailBox.ResetPackets()
return nil
}
func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID }
func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID }
func (f *mockChannelLink) UpdateShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { return 99999999 }
func (f *mockChannelLink) Peer() Peer { return f.peer }
func (f *mockChannelLink) Start() error { return nil }
func (f *mockChannelLink) Stop() {}
func (f *mockChannelLink) EligibleToForward() bool { return f.eligible }
@ -603,6 +658,10 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {
return fmt.Errorf("can't find mock invoice: %x", rhash[:])
}
if invoice.Terms.Settled {
return nil
}
invoice.Terms.Settled = true
i.invoices[rhash] = invoice

@ -1,6 +1,7 @@
package htlcswitch
import (
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
)
@ -35,6 +36,20 @@ type htlcPacket struct {
// outgoing channel.
outgoingHTLCID uint64
// sourceRef is used by forwarded htlcPackets to locate incoming Add
// entry in a fwdpkg owned by the incoming link. This value can be nil
// if there is no such entry, e.g. switch initiated payments.
sourceRef *channeldb.AddRef
// destRef is used to locate a settle/fail entry in the outgoing link's
// fwdpkg. If sourceRef is non-nil, this reference should be to a
// settle/fail in response to the sourceRef.
destRef *channeldb.SettleFailRef
// incomingAmount is the value in milli-satoshis that arrived on an
// incoming link.
incomingAmount lnwire.MilliSatoshi
// amount is the value of the HTLC that is being created or modified.
amount lnwire.MilliSatoshi
@ -50,10 +65,10 @@ type htlcPacket struct {
// encrypted with any shared secret.
localFailure bool
// isRouted is set to true if the incomingChanID and incomingHTLCID fields
// of a forwarded fail packet are already set and do not need to be looked
// up in the circuit map.
isRouted bool
// hasSource is set to true if the incomingChanID and incomingHTLCID
// fields of a forwarded fail packet are already set and do not need to
// be looked up in the circuit map.
hasSource bool
// isResolution is set to true if this packet was actually an incoming
// resolution message from an outside sub-system. We'll treat these as
@ -61,4 +76,32 @@ type htlcPacket struct {
// encrypt all errors related to this packet as if we were the first
// hop.
isResolution bool
// circuit holds a reference to an Add's circuit which is persisted in
// the switch during successful forwarding.
circuit *PaymentCircuit
}
// inKey returns the circuit key used to identify the incoming htlc.
func (p *htlcPacket) inKey() CircuitKey {
return CircuitKey{
ChanID: p.incomingChanID,
HtlcID: p.incomingHTLCID,
}
}
// outKey returns the circuit key used to identify the outgoing, forwarded htlc.
func (p *htlcPacket) outKey() CircuitKey {
return CircuitKey{
ChanID: p.outgoingChanID,
HtlcID: p.outgoingHTLCID,
}
}
// keystone returns a tuple containing the incoming and outgoing circuit keys.
func (p *htlcPacket) keystone() Keystone {
return Keystone{
InKey: p.inKey(),
OutKey: p.outKey(),
}
}

128
htlcswitch/sequencer.go Normal file

@ -0,0 +1,128 @@
package htlcswitch
import (
"sync"
"github.com/boltdb/bolt"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
)
// defaultSequenceBatchSize specifies the window of sequence numbers that are
// allocated for each write to disk made by the sequencer.
const defaultSequenceBatchSize = 1000
// Sequencer emits sequence numbers for locally initiated HTLCs. These are
// only used internally for tracking pending payments, however they must be
// unique in order to avoid circuit key collision in the circuit map.
type Sequencer interface {
// NextID returns a unique sequence number for each invocation.
NextID() (uint64, error)
}
var (
// nextPaymentIDKey identifies the bucket that will keep track of the
// persistent sequence numbers for payments.
nextPaymentIDKey = []byte("next-payment-id-key")
// ErrSequencerCorrupted signals that the persistence engine was not
// initialized, or has been corrupted since startup.
ErrSequencerCorrupted = errors.New(
"sequencer database has been corrupted")
)
// persistentSequencer is a concrete implementation of IDGenerator, that uses
// channeldb to allocate sequence numbers.
type persistentSequencer struct {
db *channeldb.DB
mu sync.Mutex
nextID uint64
horizonID uint64
}
// NewPersistentSequencer initializes a new sequencer using a channeldb backend.
func NewPersistentSequencer(db *channeldb.DB) (Sequencer, error) {
g := &persistentSequencer{
db: db,
}
// Ensure the database bucket is created before any updates are
// performed.
if err := g.initDB(); err != nil {
return nil, err
}
return g, nil
}
// NextID returns a unique sequence number for every invocation, persisting the
// assignment to avoid reuse.
func (s *persistentSequencer) NextID() (uint64, error) {
// nextID will be the unique sequence number returned if no errors are
// encountered.
var nextID uint64
// If our sequence batch has not been exhausted, we can allocate the
// next identifier in the range.
s.mu.Lock()
defer s.mu.Unlock()
if s.nextID < s.horizonID {
nextID = s.nextID
s.nextID++
return nextID, nil
}
// Otherwise, our sequence batch has been exhausted. We use the last
// known sequence number on disk to mark the beginning of the next
// sequence batch, and allocate defaultSequenceBatchSize (1000) at a
// time.
//
// NOTE: This also will happen on the first invocation after startup,
// i.e. when nextID and horizonID are both 0. The next sequence batch to be
// allocated will start from the last known tip on disk, which is fine
// as we only require uniqueness of the allocated numbers.
var nextHorizonID uint64
if err := s.db.Update(func(tx *bolt.Tx) error {
nextIDBkt := tx.Bucket(nextPaymentIDKey)
if nextIDBkt == nil {
return ErrSequencerCorrupted
}
nextID = nextIDBkt.Sequence()
nextHorizonID = nextID + defaultSequenceBatchSize
// Cannot fail when used in Update.
nextIDBkt.SetSequence(nextHorizonID)
return nil
}); err != nil {
return 0, err
}
// Never assign index zero, to avoid collisions with the EmptyKeystone.
if nextID == 0 {
nextID++
}
// If our batch sequence allocation succeed, update our in-memory values
// so we can continue to allocate sequence numbers without hitting disk.
// The nextID is incremented by one in memory so the in can be used
// issued directly on the next invocation.
s.nextID = nextID + 1
s.horizonID = nextHorizonID
return nextID, nil
}
// initDB populates the bucket used to generate payment sequence numbers.
func (s *persistentSequencer) initDB() error {
return s.db.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(nextPaymentIDKey)
return err
})
}

@ -9,6 +9,7 @@ import (
"crypto/sha256"
"github.com/boltdb/bolt"
"github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcd/btcec"
@ -26,6 +27,15 @@ var (
// ErrChannelLinkNotFound is used when channel link hasn't been found.
ErrChannelLinkNotFound = errors.New("channel link not found")
// ErrDuplicateAdd signals that the ADD htlc was already forwarded
// through the switch and is locked into another commitment txn.
ErrDuplicateAdd = errors.New("duplicate add HTLC detected")
// ErrIncompleteForward is used when an htlc was already forwarded
// through the switch, but did not get locked into another commitment
// txn.
ErrIncompleteForward = errors.Errorf("incomplete forward detected")
// zeroPreimage is the empty preimage which is returned when we have
// some errors.
zeroPreimage [sha256.Size]byte
@ -39,6 +49,7 @@ type pendingPayment struct {
amount lnwire.MilliSatoshi
preimage chan [sha256.Size]byte
response chan *htlcPacket
err chan error
// deobfuscator is an serializable entity which is used if we received
@ -110,6 +121,15 @@ type Config struct {
// forced unilateral closure of the channel initiated by a local
// subsystem.
LocalChannelClose func(pubKey []byte, request *ChanClose)
// DB is the channeldb instance that will be used to back the switch's
// persistent circuit map.
DB *channeldb.DB
// SwitchPackager provides access to the forwarding packages of all
// active channels. This gives the switch the ability to read arbitrary
// forwarding packages, and ack settles and fails contained within them.
SwitchPackager channeldb.FwdOperator
}
// Switch is the central messaging bus for all incoming/outgoing HTLCs.
@ -136,16 +156,24 @@ type Switch struct {
// integer ID when it is created.
pendingPayments map[uint64]*pendingPayment
pendingMutex sync.RWMutex
nextPendingID uint64
paymentSequencer Sequencer
// circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator.
circuits *CircuitMap
circuits CircuitMap
// links is a map of channel id and channel link which manages
// this channel.
linkIndex map[lnwire.ChannelID]ChannelLink
// mailMtx is a read/write mutex that protects the mailboxes map.
mailMtx sync.RWMutex
// mailboxes is a map of channel id to mailboxes, which allows the
// switch to buffer messages for peers that have not come back online.
mailboxes map[lnwire.ShortChannelID]MailBox
// forwardingIndex is an index which is consulted by the switch when it
// needs to locate the next hop to forward an incoming/outgoing HTLC
// update to/from.
@ -185,11 +213,23 @@ type Switch struct {
}
// New creates the new instance of htlc switch.
func New(cfg Config) *Switch {
func New(cfg Config) (*Switch, error) {
circuitMap, err := NewCircuitMap(cfg.DB)
if err != nil {
return nil, err
}
sequencer, err := NewPersistentSequencer(cfg.DB)
if err != nil {
return nil, err
}
return &Switch{
cfg: &cfg,
circuits: NewCircuitMap(),
circuits: circuitMap,
paymentSequencer: sequencer,
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
mailboxes: make(map[lnwire.ShortChannelID]MailBox),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
pendingPayments: make(map[uint64]*pendingPayment),
@ -198,7 +238,7 @@ func New(cfg Config) *Switch {
resolutionMsgs: make(chan *resolutionMsg),
linkControl: make(chan interface{}),
quit: make(chan struct{}),
}
}, nil
}
// resolutionMsg is a struct that wraps an existing ResolutionMsg with a done
@ -246,15 +286,19 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
// able to retrieve it and return response to the user.
payment := &pendingPayment{
err: make(chan error, 1),
response: make(chan *htlcPacket, 1),
preimage: make(chan [sha256.Size]byte, 1),
paymentHash: htlc.PaymentHash,
amount: htlc.Amount,
deobfuscator: deobfuscator,
}
paymentID, err := s.paymentSequencer.NextID()
if err != nil {
return zeroPreimage, err
}
s.pendingMutex.Lock()
paymentID := s.nextPendingID
s.nextPendingID++
s.pendingPayments[paymentID] = payment
s.pendingMutex.Unlock()
@ -262,10 +306,12 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
// this stage it means that packet haven't left boundaries of our
// system and something wrong happened.
packet := &htlcPacket{
incomingChanID: sourceHop,
incomingHTLCID: paymentID,
destNode: nextNode,
htlc: htlc,
}
if err := s.forward(packet); err != nil {
s.removePendingPayment(paymentID)
return zeroPreimage, err
@ -274,7 +320,7 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
// Returns channels so that other subsystem might wait/skip the
// waiting of handling of payment.
var preimage [sha256.Size]byte
var err error
var response *htlcPacket
select {
case e := <-payment.err:
@ -284,6 +330,14 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
"while waiting for payment result")
}
select {
case pkt := <-payment.response:
response = pkt
case <-s.quit:
return zeroPreimage, errors.New("htlc switch have been stopped " +
"while waiting for payment result")
}
select {
case p := <-payment.preimage:
preimage = p
@ -292,6 +346,24 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
"while waiting for payment result")
}
// Remove circuit since we are about to complete an
// add/fail of this HTLC.
if teardownErr := s.teardownCircuit(response); teardownErr != nil {
log.Warnf("unable to teardown circuit %s: %v",
response.inKey(), teardownErr)
return preimage, err
}
// Finally, if this response is contained in a forwarding package, ack
// the settle/fail so that we don't continue to retransmit the HTLC
// internally.
if response.destRef != nil {
if ackErr := s.ackSettleFail(*response.destRef); ackErr != nil {
log.Warnf("unable to ack settle/fail reference: %s: %v",
*response.destRef, ackErr)
}
}
return preimage, err
}
@ -372,6 +444,192 @@ func (s *Switch) updateLinkPolicies(c *updatePoliciesCmd) error {
// update. Also this function is used by channel links itself in order to
// forward the update after it has been included in the channel.
func (s *Switch) forward(packet *htlcPacket) error {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
circuit := newPaymentCircuit(&htlc.PaymentHash, packet)
actions, err := s.circuits.CommitCircuits(circuit)
if err != nil {
log.Errorf("unable to commit circuit in switch: %v", err)
return err
}
// Drop duplicate packet if it has already been seen.
switch {
case len(actions.Drops) == 1:
return ErrDuplicateAdd
case len(actions.Fails) == 1:
if packet.incomingChanID == sourceHop {
return err
}
failure := lnwire.NewTemporaryChannelFailure(nil)
addErr := ErrIncompleteForward
return s.failAddPacket(packet, failure, addErr)
}
packet.circuit = circuit
}
return s.route(packet)
}
// ForwardPackets adds a list of packets to the switch for processing. Fails and
// settles are added on a first past, simultaneously constructing circuits for
// any adds. After persisting the circuits, another pass of the adds is given to
// forward them through the router.
// NOTE: This method guarantees that the returned err chan will eventually be
// closed. The receiver should read on the channel until receiving such a
// signal.
func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
var (
// fwdChan is a buffered channel used to receive err msgs from
// the htlcPlex when forwarding this batch.
fwdChan = make(chan error, len(packets))
// errChan is a buffered channel returned to the caller, that is
// proxied by the fwdChan. This method guarantees that errChan
// will be closed eventually to alert the receiver that it can
// stop reading from the channel.
errChan = make(chan error, len(packets))
// numSent keeps a running count of how many packets are
// forwarded to the switch, which determines how many responses
// we will wait for on the fwdChan..
numSent int
)
// No packets, nothing to do.
if len(packets) == 0 {
close(errChan)
return errChan
}
// Setup a barrier to prevent the background tasks from processing
// responses until this function returns to the user.
var wg sync.WaitGroup
wg.Add(1)
defer wg.Done()
// Spawn a goroutine the proxy the errs back to the returned err chan.
// This is done to ensure the err chan returned to the caller closed
// properly, alerting the receiver of completion or shutdown.
s.wg.Add(1)
go s.proxyFwdErrs(&numSent, &wg, fwdChan, errChan)
// Make a first pass over the packets, forwarding any settles or fails.
// As adds are found, we create a circuit and append it to our set of
// circuits to be written to disk.
var circuits []*PaymentCircuit
var addBatch []*htlcPacket
for _, packet := range packets {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
circuit := newPaymentCircuit(&htlc.PaymentHash, packet)
packet.circuit = circuit
circuits = append(circuits, circuit)
addBatch = append(addBatch, packet)
default:
s.routeAsync(packet, fwdChan)
numSent++
}
}
// If this batch did not contain any circuits to commit, we can return
// early.
if len(circuits) == 0 {
return errChan
}
// Write any circuits that we found to disk.
actions, err := s.circuits.CommitCircuits(circuits...)
if err != nil {
log.Errorf("unable to commit circuits in switch: %v", err)
}
// Split the htlc packets by comparing an in-order seek to the head of
// the added, dropped, or failed circuits.
//
// NOTE: This assumes each list is guaranteed to be a subsequence of the
// circuits, and that the union of the sets results in the original set
// of circuits.
var addedPackets, failedPackets []*htlcPacket
for _, packet := range addBatch {
switch {
case len(actions.Adds) > 0 && packet.circuit == actions.Adds[0]:
addedPackets = append(addedPackets, packet)
actions.Adds = actions.Adds[1:]
case len(actions.Drops) > 0 && packet.circuit == actions.Drops[0]:
actions.Drops = actions.Drops[1:]
case len(actions.Fails) > 0 && packet.circuit == actions.Fails[0]:
failedPackets = append(failedPackets, packet)
actions.Fails = actions.Fails[1:]
}
}
// Now, forward any packets for circuits that were successfully added to
// the switch's circuit map.
for _, packet := range addedPackets {
s.routeAsync(packet, fwdChan)
numSent++
}
// Lastly, for any packets that failed, this implies that they were
// left in a half added state, which can happen when recovering from
// failures.
for _, packet := range failedPackets {
failure := lnwire.NewTemporaryChannelFailure(nil)
addErr := errors.Errorf("failing packet after detecting " +
"incomplete forward")
// We don't handle the error here since this method always
// returns an error.
s.failAddPacket(packet, failure, addErr)
}
return errChan
}
// proxyFwdErrs transmits any errors received on `fwdChan` back to `errChan`,
// and guarantees that the `errChan` will be closed after 1) all errors have
// been sent, or 2) the switch has received a shutdown. The `errChan` should be
// buffered with at least the value of `num` after the barrier has been
// released.
//
// NOTE: The receiver of `errChan` should read until the channel closed, since
// this proxying guarantees that the close will happen.
func (s *Switch) proxyFwdErrs(num *int, wg *sync.WaitGroup,
fwdChan, errChan chan error) {
defer s.wg.Done()
defer func() {
close(errChan)
}()
// Wait here until the outer function has finished persisting
// and routing the packets. This guarantees we don't read from num until
// the value is accurate.
wg.Wait()
numSent := *num
for i := 0; i < numSent; i++ {
select {
case err := <-fwdChan:
errChan <- err
case <-s.quit:
log.Errorf("unable to forward htlc packet " +
"htlc switch was stopped")
return
}
}
}
// route sends a single htlcPacket through the switch and synchronously awaits a
// response.
func (s *Switch) route(packet *htlcPacket) error {
command := &plexPacket{
pkt: packet,
err: make(chan error, 1),
@ -387,8 +645,24 @@ func (s *Switch) forward(packet *htlcPacket) error {
case err := <-command.err:
return err
case <-s.quit:
return errors.New("unable to forward htlc packet htlc switch was " +
"stopped")
return errors.New("Htlc Switch was stopped")
}
}
// routeAsync sends a packet through the htlc switch, using the provided err
// chan to propagate errors back to the caller. This method does not wait for
// a response before returning.
func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error {
command := &plexPacket{
pkt: packet,
err: errChan,
}
select {
case s.htlcPlex <- command:
return nil
case <-s.quit:
return errors.New("Htlc Switch was stopped")
}
}
@ -405,23 +679,23 @@ func (s *Switch) forward(packet *htlcPacket) error {
// o <-settle-- o <--settle-- o
// Alice Bob Carol
//
func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
// Pending payments use a special interpretation of the incomingChanID and
// incomingHTLCID fields on packet where the channel ID is blank and the
// HTLC ID is the payment ID. The switch basically views the users of the
// node as a special channel that also offers a sequence of HTLCs.
payment, err := s.findPayment(packet.incomingHTLCID)
payment, err := s.findPayment(pkt.incomingHTLCID)
if err != nil {
return err
}
switch htlc := packet.htlc.(type) {
switch htlc := pkt.htlc.(type) {
// User have created the htlc update therefore we should find the
// appropriate channel link and send the payment over this link.
case *lnwire.UpdateAddHTLC:
// Try to find links by node destination.
links, err := s.getLinks(packet.destNode)
links, err := s.getLinks(pkt.destNode)
if err != nil {
log.Errorf("unable to find links by destination %v", err)
return &ForwardingError{
@ -476,77 +750,25 @@ func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
// manages then channel.
//
// TODO(roasbeef): should return with an error
packet.outgoingChanID = destination.ShortChanID()
destination.HandleSwitchPacket(packet)
return nil
pkt.outgoingChanID = destination.ShortChanID()
return destination.HandleSwitchPacket(pkt)
// We've just received a settle update which means we can finalize the
// user payment and return successful response.
case *lnwire.UpdateFulfillHTLC:
// Notify the user that his payment was successfully proceed.
payment.err <- nil
payment.response <- pkt
payment.preimage <- htlc.PaymentPreimage
s.removePendingPayment(packet.incomingHTLCID)
s.removePendingPayment(pkt.incomingHTLCID)
// We've just received a fail update which means we can finalize the
// user payment and return fail response.
case *lnwire.UpdateFailHTLC:
var failure *ForwardingError
switch {
// The payment never cleared the link, so we don't need to
// decrypt the error, simply decode it them report back to the
// user.
case packet.localFailure:
var userErr string
r := bytes.NewReader(htlc.Reason)
failureMsg, err := lnwire.DecodeFailure(r, 0)
if err != nil {
userErr = fmt.Sprintf("unable to decode onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err)
log.Error(userErr)
failureMsg = lnwire.NewTemporaryChannelFailure(nil)
}
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: failureMsg,
}
// A payment had to be timed out on chain before it got past
// the first hop. In this case, we'll report a permanent
// channel failure as this means us, or the remote party had to
// go on chain.
case packet.isResolution && htlc.Reason == nil:
userErr := fmt.Sprintf("payment was resolved " +
"on-chain, then cancelled back")
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.FailPermanentChannelFailure{},
}
// A regular multi-hop payment error that we'll need to
// decrypt.
default:
// We'll attempt to fully decrypt the onion encrypted
// error. If we're unable to then we'll bail early.
failure, err = payment.deobfuscator.DecryptError(htlc.Reason)
if err != nil {
userErr := fmt.Sprintf("unable to de-obfuscate onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err)
log.Error(userErr)
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.NewTemporaryChannelFailure(nil),
}
}
}
payment.err <- failure
payment.err <- s.parseFailedPayment(payment, pkt, htlc)
payment.response <- pkt
payment.preimage <- zeroPreimage
s.removePendingPayment(packet.incomingHTLCID)
s.removePendingPayment(pkt.incomingHTLCID)
default:
return errors.New("wrong update type")
@ -555,6 +777,73 @@ func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
return nil
}
// parseFailedPayment determines the appropriate failure message to return to
// a user initiated payment. The three cases handled are:
// 1) A local failure, which should already plaintext.
// 2) A resolution from the chain arbitrator,
// 3) A failure from the remote party, which will need to be decrypted using the
// payment deobfuscator.
func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
htlc *lnwire.UpdateFailHTLC) *ForwardingError {
var failure *ForwardingError
switch {
// The payment never cleared the link, so we don't need to
// decrypt the error, simply decode it them report back to the
// user.
case pkt.localFailure:
var userErr string
r := bytes.NewReader(htlc.Reason)
failureMsg, err := lnwire.DecodeFailure(r, 0)
if err != nil {
userErr = fmt.Sprintf("unable to decode onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err)
log.Error(userErr)
failureMsg = lnwire.NewTemporaryChannelFailure(nil)
}
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: failureMsg,
}
// A payment had to be timed out on chain before it got past
// the first hop. In this case, we'll report a permanent
// channel failure as this means us, or the remote party had to
// go on chain.
case pkt.isResolution && htlc.Reason == nil:
userErr := fmt.Sprintf("payment was resolved " +
"on-chain, then cancelled back")
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.FailPermanentChannelFailure{},
}
// A regular multi-hop payment error that we'll need to
// decrypt.
default:
var err error
// We'll attempt to fully decrypt the onion encrypted
// error. If we're unable to then we'll bail early.
failure, err = payment.deobfuscator.DecryptError(htlc.Reason)
if err != nil {
userErr := fmt.Sprintf("unable to de-obfuscate onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err)
log.Error(userErr)
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.NewTemporaryChannelFailure(nil),
}
}
}
return failure
}
// handlePacketForward is used in cases when we need forward the htlc update
// from one channel link to another and be able to propagate the settle/fail
// updates back. This behaviour is achieved by creation of payment circuits.
@ -565,46 +854,22 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// payment circuit within our internal state so we can properly forward
// the ultimate settle message back latter.
case *lnwire.UpdateAddHTLC:
if packet.incomingChanID == (lnwire.ShortChannelID{}) {
// A blank incomingChanID indicates that this is a
// pending user-initiated payment.
if packet.incomingChanID == sourceHop {
// A blank incomingChanID indicates that this is
// a pending user-initiated payment.
return s.handleLocalDispatch(packet)
}
source, err := s.getLinkByShortID(packet.incomingChanID)
if err != nil {
err := errors.Errorf("unable to find channel link "+
"by channel point (%v): %v", packet.incomingChanID, err)
log.Error(err)
return err
}
targetLink, err := s.getLinkByShortID(packet.outgoingChanID)
if err != nil {
// If packet was forwarded from another channel link
// than we should notify this link that some error
// occurred.
failure := lnwire.FailUnknownNextPeer{}
reason, err := packet.obfuscator.EncryptFirstHop(failure)
if err != nil {
err := errors.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
return err
}
source.HandleSwitchPacket(&htlcPacket{
incomingChanID: packet.incomingChanID,
incomingHTLCID: packet.incomingHTLCID,
isRouted: true,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
})
err = errors.Errorf("unable to find link with "+
failure := &lnwire.FailUnknownNextPeer{}
addErr := errors.Errorf("unable to find link with "+
"destination %v", packet.outgoingChanID)
log.Error(err)
return err
return s.failAddPacket(packet, failure, addErr)
}
interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey())
@ -629,155 +894,277 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// over has insufficient capacity, then we'll cancel the htlc
// as the payment cannot succeed.
if destination == nil {
// If packet was forwarded from another
// channel link than we should notify this
// link that some error occurred.
// If packet was forwarded from another channel link
// than we should notify this link that some error
// occurred.
failure := lnwire.NewTemporaryChannelFailure(nil)
reason, err := packet.obfuscator.EncryptFirstHop(failure)
if err != nil {
err := errors.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
return err
}
source.HandleSwitchPacket(&htlcPacket{
incomingChanID: packet.incomingChanID,
incomingHTLCID: packet.incomingHTLCID,
isRouted: true,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
})
err = errors.Errorf("unable to find appropriate "+
addErr := errors.Errorf("unable to find appropriate "+
"channel link insufficient capacity, need "+
"%v", htlc.Amount)
log.Error(err)
return err
return s.failAddPacket(packet, failure, addErr)
}
// Send the packet to the destination channel link which
// manages the channel.
destination.HandleSwitchPacket(packet)
return nil
packet.outgoingChanID = destination.ShortChanID()
return destination.HandleSwitchPacket(packet)
// We've just received a settle packet which means we can finalize the
// payment circuit by forwarding the settle msg to the channel from
// which htlc add packet was initially received.
case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
if !packet.isRouted {
// Use circuit map to find the link to forward settle/fail to.
circuit := s.circuits.LookupByHTLC(packet.outgoingChanID,
packet.outgoingHTLCID)
if circuit == nil {
err := errors.Errorf("Unable to find target channel for HTLC "+
"settle/fail: channel ID = %s, HTLC ID = %d",
packet.outgoingChanID, packet.outgoingHTLCID)
log.Error(err)
return err
case *lnwire.UpdateFailHTLC, *lnwire.UpdateFulfillHTLC:
// If the source of this packet has not been set, use the
// circuit map to lookup the origin.
circuit, err := s.closeCircuit(packet)
if err != nil {
return err
}
fail, isFail := htlc.(*lnwire.UpdateFailHTLC)
if isFail && !packet.hasSource {
switch {
case circuit.ErrorEncrypter == nil:
// No message to encrypt, locally sourced
// payment.
case packet.isResolution:
// If this is a resolution message, then we'll need to encrypt
// it as it's actually internally sourced.
var err error
// TODO(roasbeef): don't need to pass actually?
failure := &lnwire.FailPermanentChannelFailure{}
fail.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop(
failure,
)
if err != nil {
err = errors.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
}
default:
// Otherwise, it's a forwarded error, so we'll perform a
// wrapper encryption as normal.
fail.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
fail.Reason,
)
}
// Remove the circuit since we are about to complete
// the HTLC.
err := s.circuits.Remove(
packet.outgoingChanID,
packet.outgoingHTLCID,
)
if err != nil {
log.Warnf("Failed to close completed onion circuit for %x: "+
"(%s, %d) <-> (%s, %d)", circuit.PaymentHash,
circuit.IncomingChanID, circuit.IncomingHTLCID,
circuit.OutgoingChanID, circuit.OutgoingHTLCID)
} else {
log.Debugf("Closed completed onion circuit for %x: "+
"(%s, %d) <-> (%s, %d)", circuit.PaymentHash,
circuit.IncomingChanID, circuit.IncomingHTLCID,
circuit.OutgoingChanID, circuit.OutgoingHTLCID)
}
packet.incomingChanID = circuit.IncomingChanID
packet.incomingHTLCID = circuit.IncomingHTLCID
} else {
// If this is an HTLC settle, and it wasn't from a
// locally initiated HTLC, then we'll log a forwarding
// event so we can flush it to disk later.
//
// TODO(roasbeef): only do this once link actually
// fully settles?
_, isSettle := packet.htlc.(*lnwire.UpdateFulfillHTLC)
localHTLC := packet.incomingChanID == (lnwire.ShortChannelID{})
if isSettle && !localHTLC {
localHTLC := packet.incomingChanID == sourceHop
if !localHTLC {
s.fwdEventMtx.Lock()
s.pendingFwdingEvents = append(
s.pendingFwdingEvents,
channeldb.ForwardingEvent{
Timestamp: time.Now(),
IncomingChanID: circuit.IncomingChanID,
OutgoingChanID: circuit.OutgoingChanID,
AmtIn: circuit.IncomingAmt,
AmtOut: circuit.OutgoingAmt,
IncomingChanID: circuit.Incoming.ChanID,
OutgoingChanID: circuit.Outgoing.ChanID,
AmtIn: circuit.IncomingAmount,
AmtOut: circuit.OutgoingAmount,
},
)
s.fwdEventMtx.Unlock()
}
// Obfuscate the error message for fail updates before
// sending back through the circuit unless the payment
// was generated locally.
if circuit.ErrorEncrypter != nil {
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok {
// If this is a resolution message,
// then we'll need to encrypt it as
// it's actually internally sourced.
if packet.isResolution {
// TODO(roasbeef): don't need to pass actually?
failure := &lnwire.FailPermanentChannelFailure{}
htlc.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop(
failure,
)
if err != nil {
err := errors.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
}
} else {
// Otherwise, it's a forwarded
// error, so we'll perform a
// wrapper encryption as
// normal.
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
htlc.Reason,
)
}
}
}
}
// For local HTLCs we'll dispatch the settle event back to the
// caller, rather than to the peer that sent us the HTLC
// originally.
localHTLC := packet.incomingChanID == (lnwire.ShortChannelID{})
if localHTLC {
// A blank IncomingChanID in a circuit indicates that it is a pending
// user-initiated payment.
if packet.incomingChanID == sourceHop {
return s.handleLocalDispatch(packet)
}
source, err := s.getLinkByShortID(packet.incomingChanID)
if err != nil {
err := errors.Errorf("Unable to get source channel "+
"link to forward HTLC settle/fail: %v", err)
log.Error(err)
return err
}
source.HandleSwitchPacket(packet)
return nil
// Check to see that the source link is online before removing
// the circuit.
sourceMailbox := s.getOrCreateMailBox(packet.incomingChanID)
return sourceMailbox.AddPacket(packet)
default:
return errors.New("wrong update type")
}
}
// failAddPacket encrypts a fail packet back to an add packet's source.
// The ciphertext will be derived from the failure message proivded by context.
// This method returns the failErr if all other steps complete successfully.
func (s *Switch) failAddPacket(packet *htlcPacket,
failure lnwire.FailureMessage, failErr error) error {
// Encrypt the failure so that the sender will be able to read the error
// message. Since we failed this packet, we use EncryptFirstHop to
// obfuscate the failure for their eyes only.
reason, err := packet.obfuscator.EncryptFirstHop(failure)
if err != nil {
err := errors.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
return err
}
log.Error(failErr)
// Route a fail packet back to the source link.
sourceMailbox := s.getOrCreateMailBox(packet.incomingChanID)
if err = sourceMailbox.AddPacket(&htlcPacket{
incomingChanID: packet.incomingChanID,
incomingHTLCID: packet.incomingHTLCID,
circuit: packet.circuit,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
}); err != nil {
err = errors.Errorf("source chanid=%v unable to "+
"handle switch packet: %v",
packet.incomingChanID, err)
log.Error(err)
return err
}
return failErr
}
// closeCircuit accepts a settle or fail htlc and the associated htlc packet and
// attempts to determine the source that forwarded this htlc. This method will
// set the incoming chan and htlc ID of the given packet if the source was
// found, and will properly [re]encrypt any failure messages.
func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
// If the packet has its source, that means it was failed locally by the
// outgoing link. We fail it here to make sure only one response makes
// it through the switch.
if pkt.hasSource {
circuit, err := s.circuits.FailCircuit(pkt.inKey())
switch err {
// Circuit successfully closed.
case nil:
return circuit, nil
// Circuit was previously closed, but has not been deleted. We'll just
// drop this response until the circuit has been fully removed.
case ErrCircuitClosing:
return nil, err
// Failed to close circuit because it does not exist. This is likely
// because the circuit was already successfully closed. Since
// this packet failed locally, there is no forwarding package
// entry to acknowledge.
case ErrUnknownCircuit:
return nil, err
// Unexpected error.
default:
return nil, err
}
}
// Otherwise, this is packet was received from the remote party.
// Use circuit map to find the incoming link to receive the settle/fail.
circuit, err := s.circuits.CloseCircuit(pkt.outKey())
switch err {
// Open circuit successfully closed.
case nil:
pkt.incomingChanID = circuit.Incoming.ChanID
pkt.incomingHTLCID = circuit.Incoming.HtlcID
pkt.circuit = circuit
pkt.sourceRef = &circuit.AddRef
return circuit, nil
// Circuit was previously closed, but has not been deleted. We'll just
// drop this response until the circuit has been removed.
case ErrCircuitClosing:
return nil, err
// Failed to close circuit because it does not exist. This is likely
// because the circuit was already successfully closed.
case ErrUnknownCircuit:
err := errors.Errorf("Unable to find target channel "+
"for HTLC settle/fail: channel ID = %s, "+
"HTLC ID = %d", pkt.outgoingChanID,
pkt.outgoingHTLCID)
log.Error(err)
// TODO(conner): ack settle/fail
if pkt.destRef != nil {
if err := s.ackSettleFail(*pkt.destRef); err != nil {
return nil, err
}
}
return nil, err
// Unexpected error.
default:
return nil, err
}
}
func (s *Switch) ackSettleFail(settleFailRef channeldb.SettleFailRef) error {
return s.cfg.DB.Update(func(tx *bolt.Tx) error {
return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRef)
})
}
// teardownCircuit removes a pending or open circuit from the switch's circuit
// map and prints useful logging statements regarding the outcome.
func (s *Switch) teardownCircuit(pkt *htlcPacket) error {
var pktType string
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateFulfillHTLC:
pktType = "SETTLE"
case *lnwire.UpdateFailHTLC:
pktType = "FAIL"
default:
err := fmt.Errorf("cannot tear down packet of type: %T", htlc)
log.Errorf(err.Error())
return err
}
switch {
case pkt.circuit.HasKeystone():
log.Debugf("Tearing down open circuit with %s pkt, removing circuit=%v "+
"with keystone=%v", pktType, pkt.inKey(), pkt.outKey())
err := s.circuits.DeleteCircuits(pkt.inKey())
if err != nil {
log.Warnf("Failed to tear down open circuit (%s, %d) <-> (%s, %d) "+
"with payment_hash-%v using %s pkt",
pkt.incomingChanID, pkt.incomingHTLCID,
pkt.outgoingChanID, pkt.outgoingHTLCID,
pkt.circuit.PaymentHash, pktType)
return err
}
log.Debugf("Closed completed %s circuit for %x: "+
"(%s, %d) <-> (%s, %d)", pktType, pkt.circuit.PaymentHash,
pkt.incomingChanID, pkt.incomingHTLCID,
pkt.outgoingChanID, pkt.outgoingHTLCID)
default:
log.Debugf("Tearing down incomplete circuit with %s for inkey=%v",
pktType, pkt.inKey())
err := s.circuits.DeleteCircuits(pkt.inKey())
if err != nil {
log.Warnf("Failed to tear down pending %s circuit for %x: "+
"(%s, %d)", pktType, pkt.circuit.PaymentHash,
pkt.incomingChanID, pkt.incomingHTLCID)
return err
}
log.Debugf("Removed pending onion circuit for %x: "+
"(%s, %d)", pkt.circuit.PaymentHash,
pkt.incomingChanID, pkt.incomingHTLCID)
}
return nil
}
// CloseLink creates and sends the close channel command to the target link
// directing the specified closure type. If the closure type if CloseRegular,
// then the last parameter should be the ideal fee-per-kw that will be used as
@ -918,7 +1305,10 @@ func (s *Switch) htlcForwarder() {
// collect all the forwarding events since the last internal,
// and write them out to our log.
case <-fwdEventTicker.C:
s.wg.Add(1)
go func() {
defer s.wg.Done()
if err := s.FlushForwardingEvents(); err != nil {
log.Errorf("unable to flush "+
"forwarding events: %v", err)
@ -1029,9 +1419,151 @@ func (s *Switch) Start() error {
s.wg.Add(1)
go s.htlcForwarder()
if err := s.reforwardResponses(); err != nil {
log.Errorf("unable to reforward responses: %v", err)
return err
}
return nil
}
// reforwardResponses for every known, non-pending channel, loads all associated
// forwarding packages and reforwards any Settle or Fail HTLCs found. This is
// used to resurrect the switch's mailboxes after a restart.
func (s *Switch) reforwardResponses() error {
activeChannels, err := s.cfg.DB.FetchAllChannels()
if err != nil {
return err
}
for _, activeChannel := range activeChannels {
if activeChannel.IsPending {
continue
}
shortChanID := activeChannel.ShortChanID
fwdPkgs, err := s.loadChannelFwdPkgs(shortChanID)
if err != nil {
return err
}
s.reforwardSettleFails(fwdPkgs)
}
return nil
}
// loadChannelFwdPkgs loads all forwarding packages owned by the `source` short
// channel identifier.
func (s *Switch) loadChannelFwdPkgs(
source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) {
var fwdPkgs []*channeldb.FwdPkg
if err := s.cfg.DB.Update(func(tx *bolt.Tx) error {
var err error
fwdPkgs, err = s.cfg.SwitchPackager.LoadChannelFwdPkgs(
tx, source,
)
return err
}); err != nil {
return nil, err
}
return fwdPkgs, nil
}
// reforwardSettleFails parses the Settle and Fail HTLCs from the list of
// forwarding packages, and reforwards those that have not been acknowledged.
// This is intended to occur on startup, in order to recover the switch's
// mailboxes, and to ensure that responses can be propagated in case the
// outgoing link never comes back online.
//
// NOTE: This should mimic the behavior processRemoteSettleFails.
func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) {
for _, fwdPkg := range fwdPkgs {
settleFails := lnwallet.PayDescsFromRemoteLogUpdates(
fwdPkg.Source, fwdPkg.Height, fwdPkg.SettleFails,
)
switchPackets := make([]*htlcPacket, 0, len(settleFails))
for i, pd := range settleFails {
// Skip any settles or fails that have already been
// acknowledged by the incoming link that originated the
// forwarded Add.
if fwdPkg.SettleFailFilter.Contains(uint16(i)) {
continue
}
switch pd.EntryType {
// A settle for an HTLC we previously forwarded HTLC has
// been received. So we'll forward the HTLC to the
// switch which will handle propagating the settle to
// the prior hop.
case lnwallet.Settle:
settlePacket := &htlcPacket{
outgoingChanID: fwdPkg.Source,
outgoingHTLCID: pd.ParentIndex,
destRef: pd.DestRef,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: pd.RPreimage,
},
}
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
switchPackets = append(switchPackets, settlePacket)
// A failureCode message for a previously forwarded HTLC has been
// received. As a result a new slot will be freed up in our
// commitment state, so we'll forward this to the switch so the
// backwards undo can continue.
case lnwallet.Fail:
// Fetch the reason the HTLC was cancelled so we can
// continue to propagate it.
failPacket := &htlcPacket{
outgoingChanID: fwdPkg.Source,
outgoingHTLCID: pd.ParentIndex,
destRef: pd.DestRef,
htlc: &lnwire.UpdateFailHTLC{
Reason: lnwire.OpaqueReason(pd.FailReason),
},
}
// Add the packet to the batch to be forwarded, and
// notify the overflow queue that a spare spot has been
// freed up within the commitment state.
switchPackets = append(switchPackets, failPacket)
}
}
errChan := s.ForwardPackets(switchPackets...)
go handleBatchFwdErrs(errChan)
}
}
// handleBatchFwdErrs waits on the given errChan until it is closed, logging the
// errors returned from any unsuccessful forwarding attempts.
func handleBatchFwdErrs(errChan chan error) {
for {
err, ok := <-errChan
if !ok {
// Err chan has been drained or switch is shutting down.
// Either way, return.
return
}
if err == nil {
continue
}
log.Errorf("unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
}
}
// Stop gracefully stops all active helper goroutines, then waits until they've
// exited.
func (s *Switch) Stop() error {
@ -1043,6 +1575,11 @@ func (s *Switch) Stop() error {
log.Infof("HTLC Switch shutting down")
close(s.quit)
for _, mailBox := range s.mailboxes {
mailBox.Stop()
}
s.wg.Wait()
return nil
@ -1096,6 +1633,14 @@ func (s *Switch) addLink(link ChannelLink) error {
}
s.interfaceIndex[peerPub][link] = struct{}{}
// Get the mailbox for this link, which buffers packets in case there
// packets that we tried to deliver while this link was offline.
mailbox := s.getOrCreateMailBox(link.ShortChanID())
// Give the link its mailbox, we only need to start the mailbox if it
// wasn't previously found.
link.AttachMailBox(mailbox)
if err := link.Start(); err != nil {
s.removeLink(link.ChanID())
return err
@ -1107,6 +1652,32 @@ func (s *Switch) addLink(link ChannelLink) error {
return nil
}
// getOrCreateMailBox returns the known mailbox for a particular short channel
// id, or creates one if the link has no existing mailbox.
func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox {
// Check to see if we have a mailbox already populated for this link.
s.mailMtx.RLock()
mailbox, ok := s.mailboxes[chanID]
if ok {
s.mailMtx.RUnlock()
return mailbox
}
s.mailMtx.RUnlock()
// Otherwise, we will make a new one only if the mailbox still is not
// present after the exclusive mutex is acquired.
s.mailMtx.Lock()
mailbox, ok = s.mailboxes[chanID]
if !ok {
mailbox = newMemoryMailBox()
mailbox.Start()
s.mailboxes[chanID] = mailbox
}
s.mailMtx.Unlock()
return mailbox
}
// getLinkCmd is a get link command wrapper, it is used to propagate handler
// parameters and return handler error.
type getLinkCmd struct {
@ -1361,15 +1932,47 @@ func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) {
return payment, nil
}
// CircuitModifier returns a reference to subset of the interfaces provided by
// the circuit map, to allow links to open and close circuits.
func (s *Switch) CircuitModifier() CircuitModifier {
return s.circuits
}
// numPendingPayments is helper function which returns the overall number of
// pending user payments.
func (s *Switch) numPendingPayments() int {
return len(s.pendingPayments)
}
// addCircuit adds a circuit to the switch's in-memory mapping.
func (s *Switch) addCircuit(circuit *PaymentCircuit) {
s.circuits.Add(circuit)
// commitCircuits persistently adds a circuit to the switch's circuit map.
func (s *Switch) commitCircuits(circuits ...*PaymentCircuit) (
*CircuitFwdActions, error) {
return s.circuits.CommitCircuits(circuits...)
}
// openCircuits preemptively writes the keystones for Adds that are about to be
// added to a commitment txn.
func (s *Switch) openCircuits(keystones ...Keystone) error {
return s.circuits.OpenCircuits(keystones...)
}
// deleteCircuits persistently removes the circuit, and keystone if present,
// from the circuit map.
func (s *Switch) deleteCircuits(inKeys ...CircuitKey) error {
return s.circuits.DeleteCircuits(inKeys...)
}
// lookupCircuit queries the in memory representation of the circuit map to
// retrieve a particular circuit.
func (s *Switch) lookupCircuit(inKey CircuitKey) *PaymentCircuit {
return s.circuits.LookupCircuit(inKey)
}
// lookupOpenCircuit queries the in-memory representation of the circuit map for a
// circuit whose outgoing circuit key matches outKey.
func (s *Switch) lookupOpenCircuit(outKey CircuitKey) *PaymentCircuit {
return s.circuits.LookupOpenCircuit(outKey)
}
// FlushForwardingEvents flushes out the set of pending forwarding events to

@ -1,8 +1,10 @@
package htlcswitch
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"io"
"io/ioutil"
"testing"
"time"
@ -11,39 +13,41 @@ import (
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/chaincfg/chainhash"
"github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil"
)
var (
hash1, _ = chainhash.NewHash(bytes.Repeat([]byte("a"), 32))
hash2, _ = chainhash.NewHash(bytes.Repeat([]byte("b"), 32))
chanPoint1 = wire.NewOutPoint(hash1, 0)
chanPoint2 = wire.NewOutPoint(hash2, 0)
chanID1 = lnwire.NewChanIDFromOutPoint(chanPoint1)
chanID2 = lnwire.NewChanIDFromOutPoint(chanPoint2)
aliceChanID = lnwire.NewShortChanIDFromInt(1)
bobChanID = lnwire.NewShortChanIDFromInt(2)
)
func genPreimage() ([32]byte, error) {
var preimage [32]byte
if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil {
return preimage, err
}
return preimage, nil
}
// TestSwitchForward checks the ability of htlc switch to forward add/settle
// requests.
func TestSwitchForward(t *testing.T) {
t.Parallel()
alicePeer := newMockServer(t, "alice")
bobPeer := newMockServer(t, "bob")
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
s := New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
s.Start()
s, err := initSwitchWithDB(nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
@ -60,13 +64,16 @@ func TestSwitchForward(t *testing.T) {
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
preimage, err := genPreimage()
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
packet := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
@ -80,12 +87,14 @@ func TestSwitchForward(t *testing.T) {
select {
case <-bobChannelLink.packets:
break
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.pending() != 1 {
if s.circuits.NumOpen() != 1 {
t.Fatal("wrong amount of circuits")
}
@ -107,17 +116,953 @@ func TestSwitchForward(t *testing.T) {
}
select {
case <-aliceChannelLink.packets:
break
case pkt := <-aliceChannelLink.packets:
if err := aliceChannelLink.deleteCircuit(pkt); err != nil {
t.Fatalf("unable to remove circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to channelPoint")
}
if s.circuits.pending() != 0 {
if s.circuits.NumOpen() != 0 {
t.Fatal("wrong amount of circuits")
}
}
func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
tempPath, err := ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to temporary path: %v", err)
}
cdb, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
s, err := initSwitchWithDB(cdb)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
// Even though we intend to Stop s later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, bobPeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := fastsha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
if s.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil {
t.Fatal(err)
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Pull packet from bob's link, but do not perform a full add.
select {
case packet := <-bobChannelLink.packets:
// Complete the payment circuit and assign the outgoing htlc id
// before restarting.
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of circuits")
}
// Now we will restart bob, leaving the forwarding decision for this
// htlc is in the half-added state.
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
if err := cdb.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb2, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to reopen channeldb: %v", err)
}
s2, err := initSwitchWithDB(cdb2)
if err != nil {
t.Fatalf("unable reinit switch: %v", err)
}
if err := s2.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
// Even though we intend to Stop s2 later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s2.Stop()
aliceChannelLink = newMockChannelLink(
s2, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink = newMockChannelLink(
s2, chanID2, bobChanID, bobPeer, true,
)
if err := s2.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s2.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of circuits")
}
// Craft a failure message from the remote peer.
fail := &htlcPacket{
outgoingChanID: bobChannelLink.ShortChanID(),
outgoingHTLCID: 0,
amount: 1,
htlc: &lnwire.UpdateFailHTLC{},
}
// Send the fail packet from the remote peer through the switch.
if err := s2.forward(fail); err != nil {
t.Fatalf(err.Error())
}
// Pull packet from alice's link, as it should have gone through
// successfully.
select {
case pkt := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(pkt); err != nil {
t.Fatalf("unable to remove circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Circuit map should be empty now.
if s2.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Send the fail packet from the remote peer through the switch.
if err := s2.forward(fail); err == nil {
t.Fatalf("expected failure when sending duplicate fail " +
"with no pending circuit")
}
}
func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
tempPath, err := ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to temporary path: %v", err)
}
cdb, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
s, err := initSwitchWithDB(cdb)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
// Even though we intend to Stop s later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, bobPeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := fastsha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
if s.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil {
t.Fatal(err)
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Pull packet from bob's link, but do not perform a full add.
select {
case packet := <-bobChannelLink.packets:
// Complete the payment circuit and assign the outgoing htlc id
// before restarting.
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of circuits")
}
// Now we will restart bob, leaving the forwarding decision for this
// htlc is in the half-added state.
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
if err := cdb.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb2, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to reopen channeldb: %v", err)
}
s2, err := initSwitchWithDB(cdb2)
if err != nil {
t.Fatalf("unable reinit switch: %v", err)
}
if err := s2.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
// Even though we intend to Stop s2 later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s2.Stop()
aliceChannelLink = newMockChannelLink(
s2, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink = newMockChannelLink(
s2, chanID2, bobChanID, bobPeer, true,
)
if err := s2.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s2.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of circuits")
}
// Craft a settle message from the remote peer.
settle := &htlcPacket{
outgoingChanID: bobChannelLink.ShortChanID(),
outgoingHTLCID: 0,
amount: 1,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: preimage,
},
}
// Send the settle packet from the remote peer through the switch.
if err := s2.forward(settle); err != nil {
t.Fatalf(err.Error())
}
// Pull packet from alice's link, as it should have gone through
// successfully.
select {
case packet := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete circuit with in key=%s: %v",
packet.inKey(), err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Circuit map should be empty now.
if s2.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Send the settle packet again, which should fail.
if err := s2.forward(settle); err == nil {
t.Fatalf("expected failure when sending duplicate settle " +
"with no pending circuit")
}
}
func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
tempPath, err := ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to temporary path: %v", err)
}
cdb, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
s, err := initSwitchWithDB(cdb)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
// Even though we intend to Stop s later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, bobPeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := fastsha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
if s.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil {
t.Fatal(err)
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of half circuits")
}
// Pull packet from bob's link, but do not perform a full add.
select {
case packet := <-bobChannelLink.packets:
// Complete the payment circuit and assign the outgoing htlc id
// before restarting.
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Now we will restart bob, leaving the forwarding decision for this
// htlc is in the half-added state.
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
if err := cdb.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb2, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to reopen channeldb: %v", err)
}
s2, err := initSwitchWithDB(cdb2)
if err != nil {
t.Fatalf("unable reinit switch: %v", err)
}
if err := s2.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
// Even though we intend to Stop s2 later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s2.Stop()
aliceChannelLink = newMockChannelLink(
s2, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink = newMockChannelLink(
s2, chanID2, bobChanID, bobPeer, true,
)
if err := s2.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s2.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 1 {
t.Fatalf("wrong amount of half circuits")
}
// Resend the failed htlc, it should be returned to alice since the
// switch will detect that it has been half added previously.
err = s2.forward(ogPacket)
if err != ErrDuplicateAdd {
t.Fatal("unexpected error when reforwarding a "+
"failed packet", err)
}
// After detecting an incomplete forward, the fail packet should have
// been returned to the sender.
select {
case <-aliceChannelLink.packets:
t.Fatal("request should not have returned to source")
case <-bobChannelLink.packets:
t.Fatal("request should not have forwarded to destination")
case <-time.After(time.Second):
}
}
func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
tempPath, err := ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to temporary path: %v", err)
}
cdb, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
s, err := initSwitchWithDB(cdb)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
// Even though we intend to Stop s later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, bobPeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := fastsha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
if s.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil {
t.Fatal(err)
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of half circuits")
}
// Pull packet from bob's link, but do not perform a full add.
select {
case <-bobChannelLink.packets:
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
// Now we will restart bob, leaving the forwarding decision for this
// htlc is in the half-added state.
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
if err := cdb.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb2, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to reopen channeldb: %v", err)
}
s2, err := initSwitchWithDB(cdb2)
if err != nil {
t.Fatalf("unable reinit switch: %v", err)
}
if err := s2.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
// Even though we intend to Stop s2 later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s2.Stop()
aliceChannelLink = newMockChannelLink(
s2, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink = newMockChannelLink(
s2, chanID2, bobChanID, bobPeer, true,
)
if err := s2.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s2.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of half circuits")
}
// Resend the failed htlc, it should be returned to alice since the
// switch will detect that it has been half added previously.
err = s2.forward(ogPacket)
if err != ErrIncompleteForward {
t.Fatal("unexpected error when reforwarding a "+
"failed packet", err)
}
// After detecting an incomplete forward, the fail packet should have
// been returned to the sender.
select {
case <-aliceChannelLink.packets:
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
}
// TestSwitchForwardCircuitPersistence checks the ability of htlc switch to
// maintain the proper entries in the circuit map in the face of restarts.
func TestSwitchForwardCircuitPersistence(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
tempPath, err := ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to temporary path: %v", err)
}
cdb, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
s, err := initSwitchWithDB(cdb)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
// Even though we intend to Stop s later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, bobPeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := fastsha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
if s.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Handle the request and checks that bob channel link received it.
if err := s.forward(ogPacket); err != nil {
t.Fatal(err)
}
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
// Retrieve packet from outgoing link and cache until after restart.
var packet *htlcPacket
select {
case packet = <-bobChannelLink.packets:
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
if err := cdb.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb2, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to reopen channeldb: %v", err)
}
s2, err := initSwitchWithDB(cdb2)
if err != nil {
t.Fatalf("unable reinit switch: %v", err)
}
if err := s2.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
// Even though we intend to Stop s2 later in the test, it is safe to
// defer this Stop since its execution it is protected by an atomic
// guard, guaranteeing it executes at most once.
defer s2.Stop()
aliceChannelLink = newMockChannelLink(
s2, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink = newMockChannelLink(
s2, chanID2, bobChanID, bobPeer, true,
)
if err := s2.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s2.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of half circuits")
}
// Now that the switch has restarted, complete the payment circuit.
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
if s2.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s2.circuits.NumOpen() != 1 {
t.Fatal("wrong amount of circuits")
}
// Create settle request pretending that bob link handled the add htlc
// request and sent the htlc settle request back. This request should
// be forwarder back to Alice link.
ogPacket = &htlcPacket{
outgoingChanID: bobChannelLink.ShortChanID(),
outgoingHTLCID: 0,
amount: 1,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: preimage,
},
}
// Handle the request and checks that payment circuit works properly.
if err := s2.forward(ogPacket); err != nil {
t.Fatal(err)
}
select {
case packet = <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete circuit with in key=%s: %v",
packet.inKey(), err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to channelPoint")
}
if s2.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits, want 1, got %d",
s2.circuits.NumPending())
}
if s2.circuits.NumOpen() != 0 {
t.Fatal("wrong amount of circuits")
}
if err := s2.Stop(); err != nil {
t.Fatal(err)
}
if err := cdb2.Close(); err != nil {
t.Fatalf(err.Error())
}
cdb3, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to reopen channeldb: %v", err)
}
s3, err := initSwitchWithDB(cdb3)
if err != nil {
t.Fatalf("unable reinit switch: %v", err)
}
if err := s3.Start(); err != nil {
t.Fatalf("unable to restart switch: %v", err)
}
defer s3.Stop()
aliceChannelLink = newMockChannelLink(
s3, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink = newMockChannelLink(
s3, chanID2, bobChanID, bobPeer, true,
)
if err := s3.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s3.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
if s3.circuits.NumPending() != 0 {
t.Fatalf("wrong amount of half circuits")
}
if s3.circuits.NumOpen() != 0 {
t.Fatalf("wrong amount of circuits")
}
}
// TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
// along, then we won't attempt to froward it down al ink that isn't yet able
// to forward any HTLC's.
@ -126,15 +1071,25 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
var packet *htlcPacket
alicePeer := newMockServer(t, "alice")
bobPeer := newMockServer(t, "bob")
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
s := New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
s.Start()
s, err := initSwitchWithDB(nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
@ -165,16 +1120,16 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
PaymentHash: rhash,
Amount: 1,
},
obfuscator: newMockObfuscator(),
obfuscator: NewMockObfuscator(),
}
// The request to forward should fail as
err := s.forward(packet)
err = s.forward(packet)
if err == nil {
t.Fatalf("forwarding should have failed due to inactive link")
}
if s.circuits.pending() != 0 {
if s.circuits.NumOpen() != 0 {
t.Fatal("wrong amount of circuits")
}
}
@ -186,14 +1141,21 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
// We'll create a single link for this test, marking it as being unable
// to forward form the get go.
alicePeer := newMockServer(t, "alice")
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
s := New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
s.Start()
s, err := initSwitchWithDB(nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
chanID1, _, aliceChanID, _ := genIDs()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, false,
@ -202,7 +1164,10 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
t.Fatalf("unable to add alice link: %v", err)
}
preimage := [sha256.Size]byte{1}
preimage, err := genPreimage()
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
addMsg := &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
@ -213,12 +1178,12 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
// outgoing link. This should fail as Alice isn't yet able to forward
// any active HTLC's.
alicePub := aliceChannelLink.Peer().PubKey()
_, err := s.SendHTLC(alicePub, addMsg, nil)
_, err = s.SendHTLC(alicePub, addMsg, nil)
if err == nil {
t.Fatalf("local forward should fail due to inactive link")
}
if s.circuits.pending() != 0 {
if s.circuits.NumOpen() != 0 {
t.Fatal("wrong amount of circuits")
}
}
@ -228,15 +1193,25 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
func TestSwitchCancel(t *testing.T) {
t.Parallel()
alicePeer := newMockServer(t, "alice")
bobPeer := newMockServer(t, "bob")
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
s := New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
s.Start()
s, err := initSwitchWithDB(nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
@ -253,13 +1228,16 @@ func TestSwitchCancel(t *testing.T) {
// Create request which should be forwarder from alice channel link
// to bob channel link.
preimage := [sha256.Size]byte{1}
preimage, err := genPreimage()
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
request := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
@ -272,13 +1250,19 @@ func TestSwitchCancel(t *testing.T) {
}
select {
case <-bobChannelLink.packets:
break
case packet := <-bobChannelLink.packets:
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.pending() != 1 {
if s.circuits.NumPending() != 1 {
t.Fatalf("wrong amount of half circuits")
}
if s.circuits.NumOpen() != 1 {
t.Fatal("wrong amount of circuits")
}
@ -298,13 +1282,19 @@ func TestSwitchCancel(t *testing.T) {
}
select {
case <-aliceChannelLink.packets:
break
case pkt := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(pkt); err != nil {
t.Fatalf("unable to remove circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to channelPoint")
}
if s.circuits.pending() != 0 {
if s.circuits.NumPending() != 0 {
t.Fatal("wrong amount of circuits")
}
if s.circuits.NumOpen() != 0 {
t.Fatal("wrong amount of circuits")
}
}
@ -314,15 +1304,25 @@ func TestSwitchCancel(t *testing.T) {
func TestSwitchAddSamePayment(t *testing.T) {
t.Parallel()
alicePeer := newMockServer(t, "alice")
bobPeer := newMockServer(t, "bob")
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
s := New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
s.Start()
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", nil)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
s, err := initSwitchWithDB(nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
@ -339,13 +1339,16 @@ func TestSwitchAddSamePayment(t *testing.T) {
// Create request which should be forwarder from alice channel link
// to bob channel link.
preimage := [sha256.Size]byte{1}
preimage, err := genPreimage()
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
request := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
@ -358,13 +1361,16 @@ func TestSwitchAddSamePayment(t *testing.T) {
}
select {
case <-bobChannelLink.packets:
break
case packet := <-bobChannelLink.packets:
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.pending() != 1 {
if s.circuits.NumOpen() != 1 {
t.Fatal("wrong amount of circuits")
}
@ -372,7 +1378,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 1,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
@ -384,7 +1390,17 @@ func TestSwitchAddSamePayment(t *testing.T) {
t.Fatal(err)
}
if s.circuits.pending() != 2 {
select {
case packet := <-bobChannelLink.packets:
if err := bobChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to destination")
}
if s.circuits.NumOpen() != 2 {
t.Fatal("wrong amount of circuits")
}
@ -404,13 +1420,16 @@ func TestSwitchAddSamePayment(t *testing.T) {
}
select {
case <-aliceChannelLink.packets:
break
case pkt := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(pkt); err != nil {
t.Fatalf("unable to remove circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to channelPoint")
}
if s.circuits.pending() != 1 {
if s.circuits.NumOpen() != 1 {
t.Fatal("wrong amount of circuits")
}
@ -427,13 +1446,16 @@ func TestSwitchAddSamePayment(t *testing.T) {
}
select {
case <-aliceChannelLink.packets:
break
case pkt := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(pkt); err != nil {
t.Fatalf("unable to remove circuit: %v", err)
}
case <-time.After(time.Second):
t.Fatal("request was not propagated to channelPoint")
}
if s.circuits.pending() != 0 {
if s.circuits.NumOpen() != 0 {
t.Fatal("wrong amount of circuits")
}
}
@ -443,14 +1465,21 @@ func TestSwitchAddSamePayment(t *testing.T) {
func TestSwitchSendPayment(t *testing.T) {
t.Parallel()
alicePeer := newMockServer(t, "alice")
alicePeer, err := newMockServer(t, "alice", nil)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
s := New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
})
s.Start()
s, err := initSwitchWithDB(nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer s.Stop()
chanID1, _, aliceChanID, _ := genIDs()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
@ -461,7 +1490,10 @@ func TestSwitchSendPayment(t *testing.T) {
// Create request which should be forwarder from alice channel link
// to bob channel link.
preimage := [sha256.Size]byte{1}
preimage, err := genPreimage()
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
update := &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
@ -485,8 +1517,11 @@ func TestSwitchSendPayment(t *testing.T) {
}()
select {
case <-aliceChannelLink.packets:
break
case packet := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case err := <-errChan:
t.Fatalf("unable to send payment: %v", err)
case <-time.After(time.Second):
@ -494,8 +1529,11 @@ func TestSwitchSendPayment(t *testing.T) {
}
select {
case <-aliceChannelLink.packets:
break
case packet := <-aliceChannelLink.packets:
if err := aliceChannelLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case err := <-errChan:
t.Fatalf("unable to send payment: %v", err)
case <-time.After(time.Second):
@ -506,14 +1544,14 @@ func TestSwitchSendPayment(t *testing.T) {
t.Fatal("wrong amount of pending payments")
}
if s.circuits.pending() != 2 {
if s.circuits.NumOpen() != 2 {
t.Fatal("wrong amount of circuits")
}
// Create fail request pretending that bob channel link handled
// the add htlc request with error and sent the htlc fail request
// back. This request should be forwarded back to alice channel link.
obfuscator := newMockObfuscator()
obfuscator := NewMockObfuscator()
failure := lnwire.FailIncorrectPaymentAmount{}
reason, err := obfuscator.EncryptFirstHop(failure)
if err != nil {

@ -4,14 +4,17 @@ import (
"bytes"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"fmt"
"io/ioutil"
"math/big"
"net"
"os"
"sync/atomic"
"testing"
"time"
"github.com/boltdb/bolt"
"github.com/btcsuite/fastsha256"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lightning-onion"
@ -53,8 +56,35 @@ var (
"3135609736119018462340006816851118", 10)
)
// mockGetChanUpdateMessage helper function which returns topology update of
// the channel
var idSeqNum uint64
func genIDs() (lnwire.ChannelID, lnwire.ChannelID, lnwire.ShortChannelID,
lnwire.ShortChannelID) {
id := atomic.AddUint64(&idSeqNum, 2)
var scratch [8]byte
binary.BigEndian.PutUint64(scratch[:], id)
hash1, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4))
binary.BigEndian.PutUint64(scratch[:], id+1)
hash2, _ := chainhash.NewHash(bytes.Repeat(scratch[:], 4))
chanPoint1 := wire.NewOutPoint(hash1, uint32(id))
chanPoint2 := wire.NewOutPoint(hash2, uint32(id+1))
chanID1 := lnwire.NewChanIDFromOutPoint(chanPoint1)
chanID2 := lnwire.NewChanIDFromOutPoint(chanPoint2)
aliceChanID := lnwire.NewShortChanIDFromInt(id)
bobChanID := lnwire.NewShortChanIDFromInt(id + 1)
return chanID1, chanID2, aliceChanID, bobChanID
}
// mockGetChanUpdateMessage helper function which returns topology update
// of the channel
func mockGetChanUpdateMessage() (*lnwire.ChannelUpdate, error) {
return &lnwire.ChannelUpdate{
Signature: wireSig,
@ -293,6 +323,8 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
cleanUpFunc := func() {
dbAlice.Close()
dbBob.Close()
os.RemoveAll(bobPath)
os.RemoveAll(alicePath)
}
@ -339,7 +371,21 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
restore := func() (*lnwallet.LightningChannel, *lnwallet.LightningChannel,
error) {
aliceStoredChannels, err := dbAlice.FetchOpenChannels(aliceKeyPub)
if err != nil {
switch err {
case nil:
case bolt.ErrDatabaseNotOpen:
dbAlice, err = channeldb.Open(dbAlice.Path())
if err != nil {
return nil, nil, errors.Errorf("unable to reopen alice "+
"db: %v", err)
}
aliceStoredChannels, err = dbAlice.FetchOpenChannels(aliceKeyPub)
if err != nil {
return nil, nil, errors.Errorf("unable to fetch alice "+
"channel: %v", err)
}
default:
return nil, nil, errors.Errorf("unable to fetch alice channel: "+
"%v", err)
}
@ -364,7 +410,21 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
}
bobStoredChannels, err := dbBob.FetchOpenChannels(bobKeyPub)
if err != nil {
switch err {
case nil:
case bolt.ErrDatabaseNotOpen:
dbBob, err = channeldb.Open(dbBob.Path())
if err != nil {
return nil, nil, errors.Errorf("unable to reopen bob "+
"db: %v", err)
}
bobStoredChannels, err = dbBob.FetchOpenChannels(bobKeyPub)
if err != nil {
return nil, nil, errors.Errorf("unable to fetch bob "+
"channel: %v", err)
}
default:
return nil, nil, errors.Errorf("unable to fetch bob channel: "+
"%v", err)
}
@ -689,8 +749,7 @@ type clusterChannels struct {
func createClusterChannels(aliceToBob, bobToCarol btcutil.Amount) (
*clusterChannels, func(), func() (*clusterChannels, error), error) {
firstChanID := lnwire.NewShortChanIDFromInt(4)
secondChanID := lnwire.NewShortChanIDFromInt(5)
_, _, firstChanID, secondChanID := genIDs()
// Create lightning channels between Alice<->Bob and Bob<->Carol
aliceChannel, firstBobChannel, cleanAliceBob, restoreAliceBob, err :=
@ -759,14 +818,29 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
secondBobChannel, carolChannel *lnwallet.LightningChannel,
startingHeight uint32) *threeHopNetwork {
aliceDb := aliceChannel.State().Db
bobDb := firstBobChannel.State().Db
carolDb := carolChannel.State().Db
// Create three peers/servers.
aliceServer := newMockServer(t, "alice")
bobServer := newMockServer(t, "bob")
carolServer := newMockServer(t, "carol")
aliceServer, err := newMockServer(t, "alice", aliceDb)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobServer, err := newMockServer(t, "bob", bobDb)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
carolServer, err := newMockServer(t, "carol", carolDb)
if err != nil {
t.Fatalf("unable to create carol server: %v", err)
}
// Create mock decoder instead of sphinx one in order to mock the route
// which htlc should follow.
decoder := &mockIteratorDecoder{}
aliceDecoder := newMockIteratorDecoder()
bobDecoder := newMockIteratorDecoder()
carolDecoder := newMockIteratorDecoder()
feeEstimator := &mockFeeEstimator{
byteFeeIn: make(chan lnwallet.SatPerVByte),
@ -783,7 +857,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
BaseFee: lnwire.NewMSatFromSatoshis(1),
TimeLockDelta: 6,
}
obfuscator := newMockObfuscator()
obfuscator := NewMockObfuscator()
aliceEpochChan := make(chan *chainntnfs.BlockEpoch)
aliceEpoch := &chainntnfs.BlockEpochEvent{
@ -794,12 +868,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
aliceTicker := time.NewTicker(50 * time.Millisecond)
aliceChannelLink := NewChannelLink(
ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: bobServer,
Switch: aliceServer.htlcSwitch,
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (ErrorEncrypter,
lnwire.FailCode) {
FwrdingPolicy: globalPolicy,
Peer: bobServer,
Circuits: aliceServer.htlcSwitch.CircuitModifier(),
ForwardPackets: aliceServer.htlcSwitch.ForwardPackets,
DecodeHopIterator: aliceDecoder.DecodeHopIterator,
DecodeHopIterators: aliceDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
GetLastChannelUpdate: mockGetChanUpdateMessage,
@ -810,10 +886,11 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil
},
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{aliceTicker.C},
BatchSize: 10,
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{aliceTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10,
},
aliceChannel,
startingHeight,
@ -840,12 +917,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
firstBobTicker := time.NewTicker(50 * time.Millisecond)
firstBobChannelLink := NewChannelLink(
ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: aliceServer,
Switch: bobServer.htlcSwitch,
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (ErrorEncrypter,
lnwire.FailCode) {
FwrdingPolicy: globalPolicy,
Peer: aliceServer,
Circuits: bobServer.htlcSwitch.CircuitModifier(),
ForwardPackets: bobServer.htlcSwitch.ForwardPackets,
DecodeHopIterator: bobDecoder.DecodeHopIterator,
DecodeHopIterators: bobDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
GetLastChannelUpdate: mockGetChanUpdateMessage,
@ -856,10 +935,11 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil
},
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{firstBobTicker.C},
BatchSize: 10,
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{firstBobTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10,
},
firstBobChannel,
startingHeight,
@ -886,12 +966,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
secondBobTicker := time.NewTicker(50 * time.Millisecond)
secondBobChannelLink := NewChannelLink(
ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: carolServer,
Switch: bobServer.htlcSwitch,
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (ErrorEncrypter,
lnwire.FailCode) {
FwrdingPolicy: globalPolicy,
Peer: carolServer,
Circuits: bobServer.htlcSwitch.CircuitModifier(),
ForwardPackets: bobServer.htlcSwitch.ForwardPackets,
DecodeHopIterator: bobDecoder.DecodeHopIterator,
DecodeHopIterators: bobDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
GetLastChannelUpdate: mockGetChanUpdateMessage,
@ -902,10 +984,11 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil
},
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{secondBobTicker.C},
BatchSize: 10,
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{secondBobTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10,
},
secondBobChannel,
startingHeight,
@ -932,12 +1015,14 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
carolTicker := time.NewTicker(50 * time.Millisecond)
carolChannelLink := NewChannelLink(
ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: bobServer,
Switch: carolServer.htlcSwitch,
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (ErrorEncrypter,
lnwire.FailCode) {
FwrdingPolicy: globalPolicy,
Peer: bobServer,
Circuits: carolServer.htlcSwitch.CircuitModifier(),
ForwardPackets: carolServer.htlcSwitch.ForwardPackets,
DecodeHopIterator: carolDecoder.DecodeHopIterator,
DecodeHopIterators: carolDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
GetLastChannelUpdate: mockGetChanUpdateMessage,
@ -948,10 +1033,11 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
UpdateContractSignals: func(*contractcourt.ContractSignals) error {
return nil
},
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{carolTicker.C},
BatchSize: 10,
ChainEvents: &contractcourt.ChainEventSubscription{},
SyncStates: true,
BatchTicker: &mockTicker{carolTicker.C},
FwdPkgGCTicker: &mockTicker{time.NewTicker(5 * time.Second).C},
BatchSize: 10,
},
carolChannel,
startingHeight,

@ -1973,6 +1973,179 @@ func testChannelForceClosure(net *lntest.NetworkHarness, t *harnessTest) {
}
}
// testSphinxReplayPersistence verifies that replayed onion packets are rejected
// by a remote peer after a restart. We use a combination of unsafe
// configuration arguments to force Carol to replay the same sphinx packet after
// reconnecting to Dave, and compare the returned failure message with what we
// expect for replayed onion packets.
func testSphinxReplayPersistence(net *lntest.NetworkHarness, t *harnessTest) {
ctxb := context.Background()
timeout := time.Duration(time.Second * 5)
// Open a channel with 100k satoshis between Carol and Dave with Carol being
// the sole funder of the channel.
chanAmt := btcutil.Amount(100000)
// First, we'll create Dave, the receiver, and start him in hodl mode.
dave, err := net.NewNode([]string{"--debughtlc", "--hodlhtlc"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
// Next, we'll create Carol and establish a channel to from her to
// Dave. Carol is started in both unsafe-replay and unsafe-disconnect,
// which will cause her to replay any pending Adds held in memory upon
// reconnection.
carol, err := net.NewNode([]string{"--unsafe-replay"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, carol, dave); err != nil {
t.Fatalf("unable to connect carol to dave: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, carol)
if err != nil {
t.Fatalf("unable to send coins to carol: %v", err)
}
ctxt, _ := context.WithTimeout(ctxb, timeout)
chanPoint := openChannelAndAssert(ctxt, t, net, carol,
dave, chanAmt, 0)
assertAmountSent := func(amt btcutil.Amount) {
// Both channels should also have properly accounted from the
// amount that has been sent/received over the channel.
listReq := &lnrpc.ListChannelsRequest{}
carolListChannels, err := carol.ListChannels(ctxb, listReq)
if err != nil {
t.Fatalf("unable to query for alice's channel list: %v", err)
}
carolSatoshisSent := carolListChannels.Channels[0].TotalSatoshisSent
if carolSatoshisSent != int64(amt) {
t.Fatalf("Carol's satoshis sent is incorrect got %v, expected %v",
carolSatoshisSent, amt)
}
daveListChannels, err := dave.ListChannels(ctxb, listReq)
if err != nil {
t.Fatalf("unable to query for Dave's channel list: %v", err)
}
daveSatoshisReceived := daveListChannels.Channels[0].TotalSatoshisReceived
if daveSatoshisReceived != int64(amt) {
t.Fatalf("Dave's satoshis received is incorrect got %v, expected %v",
daveSatoshisReceived, amt)
}
}
// Now that the channel is open, create an invoice for Dave which
// expects a payment of 1000 satoshis from Carol paid via a particular
// preimage.
const paymentAmt = 1000
preimage := bytes.Repeat([]byte("A"), 32)
invoice := &lnrpc.Invoice{
Memo: "testing",
RPreimage: preimage,
Value: paymentAmt,
}
invoiceResp, err := dave.AddInvoice(ctxb, invoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
// Wait for Carol to recognize and advertise the new channel generated
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = carol.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("alice didn't advertise channel before "+
"timeout: %v", err)
}
err = dave.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("bob didn't advertise channel before "+
"timeout: %v", err)
}
// With the invoice for Dave added, send a payment from Carol paying
// to the above generated invoice.
ctx, cancel := context.WithCancel(ctxb)
defer cancel()
payStream, err := carol.SendPayment(ctx)
if err != nil {
t.Fatalf("unable to open payment stream: %v", err)
}
sendReq := &lnrpc.SendRequest{PaymentRequest: invoiceResp.PaymentRequest}
err = payStream.Send(sendReq)
if err != nil {
t.Fatalf("unable to send payment: %v", err)
}
time.Sleep(200 * time.Millisecond)
// Dave's invoice should not be marked as settled.
payHash := &lnrpc.PaymentHash{
RHash: invoiceResp.RHash,
}
dbInvoice, err := dave.LookupInvoice(ctxb, payHash)
if err != nil {
t.Fatalf("unable to lookup invoice: %v", err)
}
if dbInvoice.Settled {
t.Fatalf("dave's invoice should not be marked as settled: %v",
spew.Sdump(dbInvoice))
}
// With the payment sent but hedl, all balance related stats should not
// have changed.
time.Sleep(time.Millisecond * 200)
assertAmountSent(0)
// With the first payment sent, restart dave to make sure he is
// persisting the information required to detect replayed sphinx
// packets.
if err := net.RestartNode(dave, nil); err != nil {
t.Fatalf("unable to restart dave: %v", err)
}
// Carol should retransmit the Add hedl in her mailbox on startup. Dave
// should not accept the replayed Add, and actually fail back the
// pending payment. Even though he still holds the original settle, if
// he does fail, it is almost certainly caused by the sphinx replay
// protection, as it is the only validation we do in hodl mode.
resp, err := payStream.Recv()
if err != nil {
t.Fatalf("unable to receive payment response: %v", err)
}
// Construct the response we expect after sending a duplicate packet
// that fails due to sphinx replay detection.
replayErr := fmt.Sprintf("unable to route payment to destination: "+
"TemporaryChannelFailure: unable to de-obfuscate onion failure, "+
"htlc with hash(%x): unable to retrieve onion failure",
invoiceResp.RHash)
if resp.PaymentError != replayErr {
t.Fatalf("received payment error: %v", resp.PaymentError)
}
// Since the payment failed, the balance should still be left unaltered.
assertAmountSent(0)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, carol, chanPoint, false)
// Finally, shutdown the nodes we created for the duration of the tests,
// only leaving the two seed nodes (Alice and Bob) within our test
// network.
if err := net.ShutdownNode(carol); err != nil {
t.Fatalf("unable to shutdown carol: %v", err)
}
if err := net.ShutdownNode(dave); err != nil {
t.Fatalf("unable to shutdown dave: %v", err)
}
}
func testSingleHopInvoice(net *lntest.NetworkHarness, t *harnessTest) {
ctxb := context.Background()
timeout := time.Duration(time.Second * 5)
@ -3284,16 +3457,18 @@ func testRevokedCloseRetribution(net *lntest.NetworkHarness, t *harnessTest) {
var bobChan *lnrpc.ActiveChannel
var predErr error
err = lntest.WaitPredicate(func() bool {
bobChan, err = getBobChanInfo()
bChan, err := getBobChanInfo()
if err != nil {
t.Fatalf("unable to get bob's channel info: %v", err)
}
if bobChan.LocalBalance != 30000 {
if bChan.LocalBalance != 30000 {
predErr = fmt.Errorf("bob's balance is incorrect, "+
"got %v, expected %v", bobChan.LocalBalance,
"got %v, expected %v", bChan.LocalBalance,
30000)
return false
}
bobChan = bChan
return true
}, time.Second*15)
if err != nil {
@ -6448,6 +6623,1116 @@ func testMultiHopHtlcRemoteChainClaim(net *lntest.NetworkHarness, t *harnessTest
}
}
// testSwitchCircuitPersistence creates a multihop network to ensure the sender
// and intermediaries are persisting their open payment circuits. After
// forwarding a packet via an outgoing link, all are restarted, and expected to
// forward a response back from the receiver once back online.
//
// The general flow of this test:
// 1. Carol --> Dave --> Alice --> Bob forward payment
// 2. X X X Bob restart sender and intermediaries
// 3. Carol <-- Dave <-- Alice <-- Bob expect settle to propagate
func testSwitchCircuitPersistence(net *lntest.NetworkHarness, t *harnessTest) {
const chanAmt = btcutil.Amount(1000000)
const pushAmt = btcutil.Amount(900000)
ctxb := context.Background()
timeout := time.Duration(time.Second * 15)
var networkChans []*lnrpc.ChannelPoint
// Open a channel with 100k satoshis between Alice and Bob with Alice
// being the sole funder of the channel.
ctxt, _ := context.WithTimeout(ctxb, timeout)
chanPointAlice := openChannelAndAssert(ctxt, t, net, net.Alice,
net.Bob, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointAlice)
txidHash, err := getChanPointFundingTxid(chanPointAlice)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
aliceChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
aliceFundPoint := wire.OutPoint{
Hash: *aliceChanTXID,
Index: chanPointAlice.OutputIndex,
}
// As preliminary setup, we'll create two new nodes: Carol and Dave,
// such that we now have a 4 ndoe, 3 channel topology. Dave will make
// a channel with Alice, and Carol with Dave. After this setup, the
// network topology should now look like:
// Carol -> Dave -> Alice -> Bob
//
// First, we'll create Dave and establish a channel to Alice.
dave, err := net.NewNode(nil)
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, dave, net.Alice); err != nil {
t.Fatalf("unable to connect dave to alice: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, dave)
if err != nil {
t.Fatalf("unable to send coins to dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointDave := openChannelAndAssert(ctxt, t, net, dave,
net.Alice, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointDave)
txidHash, err = getChanPointFundingTxid(chanPointDave)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
daveChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
daveFundPoint := wire.OutPoint{
Hash: *daveChanTXID,
Index: chanPointDave.OutputIndex,
}
// Next, we'll create Carol and establish a channel to from her to
// Dave. Carol is started in htlchodl mode so that we can disconnect the
// intermediary hops before starting the settle.
carol, err := net.NewNode([]string{"--debughtlc", "--hodlhtlc"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, carol, dave); err != nil {
t.Fatalf("unable to connect carol to dave: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, carol)
if err != nil {
t.Fatalf("unable to send coins to carol: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointCarol := openChannelAndAssert(ctxt, t, net, carol,
dave, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointCarol)
txidHash, err = getChanPointFundingTxid(chanPointCarol)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
carolChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
carolFundPoint := wire.OutPoint{
Hash: *carolChanTXID,
Index: chanPointCarol.OutputIndex,
}
// Wait for all nodes to have seen all channels.
nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave}
nodeNames := []string{"Alice", "Bob", "Carol", "Dave"}
for _, chanPoint := range networkChans {
for i, node := range nodes {
txidHash, err := getChanPointFundingTxid(chanPoint)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
txid, e := chainhash.NewHash(txidHash)
if e != nil {
t.Fatalf("unable to create sha hash: %v", e)
}
point := wire.OutPoint{
Hash: *txid,
Index: chanPoint.OutputIndex,
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("%s(%d): timeout waiting for "+
"channel(%s) open: %v", nodeNames[i],
node.NodeID, point, err)
}
}
}
// Create 5 invoices for Carol, which expect a payment from Bob for 1k
// satoshis with a different preimage each time.
const numPayments = 5
const paymentAmt = 1000
payReqs := make([]string, numPayments)
for i := 0; i < numPayments; i++ {
invoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, invoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs[i] = resp.PaymentRequest
}
// We'll wait for all parties to recognize the new channels within the
// network.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = dave.WaitForNetworkChannelOpen(ctxt, chanPointDave)
if err != nil {
t.Fatalf("dave didn't advertise his channel: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = carol.WaitForNetworkChannelOpen(ctxt, chanPointCarol)
if err != nil {
t.Fatalf("carol didn't advertise her channel in time: %v",
err)
}
time.Sleep(time.Millisecond * 50)
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, false)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
time.Sleep(time.Millisecond * 200)
// Restart the intermediaries and the sender.
if err := net.RestartNode(dave, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
if err := net.RestartNode(net.Alice, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
if err := net.RestartNode(net.Bob, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
// Now restart carol without hodl mode, to settle back the outstanding
// payments.
carol.SetExtraArgs(nil)
if err := net.RestartNode(carol, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
time.Sleep(time.Second * 5)
// When asserting the amount of satoshis moved, we'll factor in the
// default base fee, as we didn't modify the fee structure when
// creating the seed nodes in the network.
const baseFee = 1
// At this point all the channels within our proto network should be
// shifted by 5k satoshis in the direction of Carol, the sink within the
// payment flow generated above. The order of asserts corresponds to
// increasing of time is needed to embed the HTLC in commitment
// transaction, in channel Bob->Alice->David->Carol, order is Carol,
// David, Alice, Bob.
var amountPaid = int64(5000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*numPayments))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*numPayments), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*numPayments)*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*numPayments)*2, int64(0))
// Lastly, we will send one more payment to ensure all channels are
// still functioning properly.
finalInvoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, finalInvoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs = []string{resp.PaymentRequest}
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, true)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
amountPaid = int64(6000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*(numPayments+1)))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*(numPayments+1)), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*(numPayments+1))*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*(numPayments+1))*2, int64(0))
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, chanPointAlice, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, dave, chanPointDave, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, carol, chanPointCarol, false)
// Finally, shutdown the nodes we created for the duration of the tests,
// only leaving the two seed nodes (Alice and Bob) within our test
// network.
if err := net.ShutdownNode(carol); err != nil {
t.Fatalf("unable to shutdown carol: %v", err)
}
if err := net.ShutdownNode(dave); err != nil {
t.Fatalf("unable to shutdown dave: %v", err)
}
}
// testSwitchOfflineDelivery constructs a set of multihop payments, and tests
// that the returning payments are not lost if a peer on the backwards path is
// offline when the settle/fails are received. We expect the payments to be
// buffered in memory, and transmitted as soon as the disconnect link comes back
// online.
//
// The general flow of this test:
// 1. Carol --> Dave --> Alice --> Bob forward payment
// 2. Carol --- Dave X Alice --- Bob disconnect intermediaries
// 3. Carol --- Dave X Alice <-- Bob settle last hop
// 4. Carol <-- Dave <-- Alice --- Bob reconnect, expect settle to propagate
func testSwitchOfflineDelivery(net *lntest.NetworkHarness, t *harnessTest) {
const chanAmt = btcutil.Amount(1000000)
const pushAmt = btcutil.Amount(900000)
ctxb := context.Background()
timeout := time.Duration(time.Second * 15)
var networkChans []*lnrpc.ChannelPoint
// Open a channel with 100k satoshis between Alice and Bob with Alice
// being the sole funder of the channel.
ctxt, _ := context.WithTimeout(ctxb, timeout)
chanPointAlice := openChannelAndAssert(ctxt, t, net, net.Alice,
net.Bob, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointAlice)
txidHash, err := getChanPointFundingTxid(chanPointAlice)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
aliceChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
aliceFundPoint := wire.OutPoint{
Hash: *aliceChanTXID,
Index: chanPointAlice.OutputIndex,
}
// As preliminary setup, we'll create two new nodes: Carol and Dave,
// such that we now have a 4 ndoe, 3 channel topology. Dave will make
// a channel with Alice, and Carol with Dave. After this setup, the
// network topology should now look like:
// Carol -> Dave -> Alice -> Bob
//
// First, we'll create Dave and establish a channel to Alice.
dave, err := net.NewNode([]string{"--unsafe-disconnect"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, dave, net.Alice); err != nil {
t.Fatalf("unable to connect dave to alice: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, dave)
if err != nil {
t.Fatalf("unable to send coins to dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointDave := openChannelAndAssert(ctxt, t, net, dave,
net.Alice, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointDave)
txidHash, err = getChanPointFundingTxid(chanPointDave)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
daveChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
daveFundPoint := wire.OutPoint{
Hash: *daveChanTXID,
Index: chanPointDave.OutputIndex,
}
// Next, we'll create Carol and establish a channel to from her to
// Dave. Carol is started in htlchodl mode so that we can disconnect the
// intermediary hops before starting the settle.
carol, err := net.NewNode([]string{"--debughtlc", "--hodlhtlc"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, carol, dave); err != nil {
t.Fatalf("unable to connect carol to dave: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, carol)
if err != nil {
t.Fatalf("unable to send coins to carol: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointCarol := openChannelAndAssert(ctxt, t, net, carol,
dave, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointCarol)
txidHash, err = getChanPointFundingTxid(chanPointCarol)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
carolChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
carolFundPoint := wire.OutPoint{
Hash: *carolChanTXID,
Index: chanPointCarol.OutputIndex,
}
// Wait for all nodes to have seen all channels.
nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave}
nodeNames := []string{"Alice", "Bob", "Carol", "Dave"}
for _, chanPoint := range networkChans {
for i, node := range nodes {
txidHash, err := getChanPointFundingTxid(chanPoint)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
txid, e := chainhash.NewHash(txidHash)
if e != nil {
t.Fatalf("unable to create sha hash: %v", e)
}
point := wire.OutPoint{
Hash: *txid,
Index: chanPoint.OutputIndex,
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("%s(%d): timeout waiting for "+
"channel(%s) open: %v", nodeNames[i],
node.NodeID, point, err)
}
}
}
// Create 5 invoices for Carol, which expect a payment from Bob for 1k
// satoshis with a different preimage each time.
const numPayments = 5
const paymentAmt = 1000
payReqs := make([]string, numPayments)
for i := 0; i < numPayments; i++ {
invoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, invoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs[i] = resp.PaymentRequest
}
// We'll wait for all parties to recognize the new channels within the
// network.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = dave.WaitForNetworkChannelOpen(ctxt, chanPointDave)
if err != nil {
t.Fatalf("dave didn't advertise his channel: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = carol.WaitForNetworkChannelOpen(ctxt, chanPointCarol)
if err != nil {
t.Fatalf("carol didn't advertise her channel in time: %v",
err)
}
time.Sleep(time.Millisecond * 50)
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, false)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
time.Sleep(2 * time.Second)
// First, disconnect Dave and Alice so that their link is broken.
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.DisconnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to disconnect alice from dave: %v", err)
}
// Then, reconnect them to ensure Dave doesn't just fail back the htlc.
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.ConnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to reconnect alice to dave: %v", err)
}
// Now, disconnect Dave from Alice again before settling back the
// payment.
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.DisconnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to disconnect alice from dave: %v", err)
}
// Now restart carol without hodl mode, to settle back the outstanding
// payments.
carol.SetExtraArgs(nil)
if err := net.RestartNode(carol, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
time.Sleep(200 * time.Millisecond)
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.ConnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to reconnect alice to dave: %v", err)
}
time.Sleep(200 * time.Millisecond)
// When asserting the amount of satoshis moved, we'll factor in the
// default base fee, as we didn't modify the fee structure when
// creating the seed nodes in the network.
const baseFee = 1
// At this point all the channels within our proto network should be
// shifted by 5k satoshis in the direction of Carol, the sink within the
// payment flow generated above. The order of asserts corresponds to
// increasing of time is needed to embed the HTLC in commitment
// transaction, in channel Bob->Alice->David->Carol, order is Carol,
// David, Alice, Bob.
var amountPaid = int64(5000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*numPayments))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*numPayments), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*numPayments)*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*numPayments)*2, int64(0))
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.DisconnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to disconnect alice from dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.ConnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to reconnect alice to dave: %v", err)
}
// Lastly, we will send one more payment to ensure all channels are
// still functioning properly.
finalInvoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, finalInvoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs = []string{resp.PaymentRequest}
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, true)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
amountPaid = int64(6000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*(numPayments+1)))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*(numPayments+1)), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*(numPayments+1))*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*(numPayments+1))*2, int64(0))
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, chanPointAlice, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, dave, chanPointDave, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, carol, chanPointCarol, false)
// Finally, shutdown the nodes we created for the duration of the tests,
// only leaving the two seed nodes (Alice and Bob) within our test
// network.
if err := net.ShutdownNode(carol); err != nil {
t.Fatalf("unable to shutdown carol: %v", err)
}
if err := net.ShutdownNode(dave); err != nil {
t.Fatalf("unable to shutdown dave: %v", err)
}
}
// testSwitchOfflineDeliveryPersistence constructs a set of multihop payments,
// and tests that the returning payments are not lost if a peer on the backwards
// path is offline when the settle/fails are received AND the peer buffering the
// responses is completely restarts. We expect the payments to be reloaded from
// disk, and transmitted as soon as the intermediaries are reconnected.
//
// The general flow of this test:
// 1. Carol --> Dave --> Alice --> Bob forward payment
// 2. Carol --- Dave X Alice --- Bob disconnect intermediaries
// 3. Carol --- Dave X Alice <-- Bob settle last hop
// 4. Carol --- Dave X X Bob restart Alice
// 5. Carol <-- Dave <-- Alice --- Bob expect settle to propagate
func testSwitchOfflineDeliveryPersistence(net *lntest.NetworkHarness, t *harnessTest) {
const chanAmt = btcutil.Amount(1000000)
const pushAmt = btcutil.Amount(900000)
ctxb := context.Background()
timeout := time.Duration(time.Second * 15)
var networkChans []*lnrpc.ChannelPoint
// Open a channel with 100k satoshis between Alice and Bob with Alice
// being the sole funder of the channel.
ctxt, _ := context.WithTimeout(ctxb, timeout)
chanPointAlice := openChannelAndAssert(ctxt, t, net, net.Alice,
net.Bob, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointAlice)
txidHash, err := getChanPointFundingTxid(chanPointAlice)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
aliceChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
aliceFundPoint := wire.OutPoint{
Hash: *aliceChanTXID,
Index: chanPointAlice.OutputIndex,
}
// As preliminary setup, we'll create two new nodes: Carol and Dave,
// such that we now have a 4 ndoe, 3 channel topology. Dave will make
// a channel with Alice, and Carol with Dave. After this setup, the
// network topology should now look like:
// Carol -> Dave -> Alice -> Bob
//
// First, we'll create Dave and establish a channel to Alice.
dave, err := net.NewNode([]string{"--unsafe-disconnect"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, dave, net.Alice); err != nil {
t.Fatalf("unable to connect dave to alice: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, dave)
if err != nil {
t.Fatalf("unable to send coins to dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointDave := openChannelAndAssert(ctxt, t, net, dave,
net.Alice, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointDave)
txidHash, err = getChanPointFundingTxid(chanPointDave)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
daveChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
daveFundPoint := wire.OutPoint{
Hash: *daveChanTXID,
Index: chanPointDave.OutputIndex,
}
// Next, we'll create Carol and establish a channel to from her to
// Dave. Carol is started in htlchodl mode so that we can disconnect the
// intermediary hops before starting the settle.
carol, err := net.NewNode([]string{"--debughtlc", "--hodlhtlc"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, carol, dave); err != nil {
t.Fatalf("unable to connect carol to dave: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, carol)
if err != nil {
t.Fatalf("unable to send coins to carol: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointCarol := openChannelAndAssert(ctxt, t, net, carol,
dave, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointCarol)
txidHash, err = getChanPointFundingTxid(chanPointCarol)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
carolChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
carolFundPoint := wire.OutPoint{
Hash: *carolChanTXID,
Index: chanPointCarol.OutputIndex,
}
// Wait for all nodes to have seen all channels.
nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave}
nodeNames := []string{"Alice", "Bob", "Carol", "Dave"}
for _, chanPoint := range networkChans {
for i, node := range nodes {
txidHash, err := getChanPointFundingTxid(chanPoint)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
txid, e := chainhash.NewHash(txidHash)
if e != nil {
t.Fatalf("unable to create sha hash: %v", e)
}
point := wire.OutPoint{
Hash: *txid,
Index: chanPoint.OutputIndex,
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("%s(%d): timeout waiting for "+
"channel(%s) open: %v", nodeNames[i],
node.NodeID, point, err)
}
}
}
// Create 5 invoices for Carol, which expect a payment from Bob for 1k
// satoshis with a different preimage each time.
const numPayments = 5
const paymentAmt = 1000
payReqs := make([]string, numPayments)
for i := 0; i < numPayments; i++ {
invoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, invoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs[i] = resp.PaymentRequest
}
// We'll wait for all parties to recognize the new channels within the
// network.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = dave.WaitForNetworkChannelOpen(ctxt, chanPointDave)
if err != nil {
t.Fatalf("dave didn't advertise his channel: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = carol.WaitForNetworkChannelOpen(ctxt, chanPointCarol)
if err != nil {
t.Fatalf("carol didn't advertise her channel in time: %v",
err)
}
time.Sleep(time.Millisecond * 50)
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, false)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
time.Sleep(2 * time.Second)
// Restart the intermediaries and the sender.
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.DisconnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to disconnect alice from dave: %v", err)
}
// Now restart carol without hodl mode, to settle back the outstanding
// payments.
carol.SetExtraArgs(nil)
if err := net.RestartNode(carol, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
time.Sleep(200 * time.Millisecond)
if err := net.RestartNode(dave, nil); err != nil {
t.Fatalf("unable to reconnect alice to dave: %v", err)
}
time.Sleep(200 * time.Millisecond)
// When asserting the amount of satoshis moved, we'll factor in the
// default base fee, as we didn't modify the fee structure when
// creating the seed nodes in the network.
const baseFee = 1
// At this point all the channels within our proto network should be
// shifted by 5k satoshis in the direction of Carol, the sink within the
// payment flow generated above. The order of asserts corresponds to
// increasing of time is needed to embed the HTLC in commitment
// transaction, in channel Bob->Alice->David->Carol, order is Carol,
// David, Alice, Bob.
var amountPaid = int64(5000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*numPayments))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*numPayments), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*numPayments)*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*numPayments)*2, int64(0))
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.DisconnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to disconnect alice from dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.ConnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to reconnect alice to dave: %v", err)
}
// Lastly, we will send one more payment to ensure all channels are
// still functioning properly.
finalInvoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, finalInvoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs = []string{resp.PaymentRequest}
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, true)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
amountPaid = int64(6000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*(numPayments+1)))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*(numPayments+1)), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*(numPayments+1))*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*(numPayments+1))*2, int64(0))
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, chanPointAlice, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, dave, chanPointDave, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, carol, chanPointCarol, false)
// Finally, shutdown the nodes we created for the duration of the tests,
// only leaving the two seed nodes (Alice and Bob) within our test
// network.
if err := net.ShutdownNode(carol); err != nil {
t.Fatalf("unable to shutdown carol: %v", err)
}
if err := net.ShutdownNode(dave); err != nil {
t.Fatalf("unable to shutdown dave: %v", err)
}
}
// testSwitchOfflineDeliveryOutgoingOffline constructs a set of multihop payments,
// and tests that the returning payments are not lost if a peer on the backwards
// path is offline when the settle/fails are received AND the peer buffering the
// responses is completely restarts. We expect the payments to be reloaded from
// disk, and transmitted as soon as the intermediaries are reconnected.
//
// The general flow of this test:
// 1. Carol --> Dave --> Alice --> Bob forward payment
// 2. Carol --- Dave X Alice --- Bob disconnect intermediaries
// 3. Carol --- Dave X Alice <-- Bob settle last hop
// 4. Carol --- Dave X X shutdown Bob, restart Alice
// 5. Carol <-- Dave <-- Alice X expect settle to propagate
func testSwitchOfflineDeliveryOutgoingOffline(
net *lntest.NetworkHarness, t *harnessTest) {
const chanAmt = btcutil.Amount(1000000)
const pushAmt = btcutil.Amount(900000)
ctxb := context.Background()
timeout := time.Duration(time.Second * 15)
var networkChans []*lnrpc.ChannelPoint
// Open a channel with 100k satoshis between Alice and Bob with Alice
// being the sole funder of the channel.
ctxt, _ := context.WithTimeout(ctxb, timeout)
chanPointAlice := openChannelAndAssert(ctxt, t, net, net.Alice,
net.Bob, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointAlice)
txidHash, err := getChanPointFundingTxid(chanPointAlice)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
aliceChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
aliceFundPoint := wire.OutPoint{
Hash: *aliceChanTXID,
Index: chanPointAlice.OutputIndex,
}
// As preliminary setup, we'll create two new nodes: Carol and Dave,
// such that we now have a 4 ndoe, 3 channel topology. Dave will make
// a channel with Alice, and Carol with Dave. After this setup, the
// network topology should now look like:
// Carol -> Dave -> Alice -> Bob
//
// First, we'll create Dave and establish a channel to Alice.
dave, err := net.NewNode([]string{"--unsafe-disconnect"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, dave, net.Alice); err != nil {
t.Fatalf("unable to connect dave to alice: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, dave)
if err != nil {
t.Fatalf("unable to send coins to dave: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointDave := openChannelAndAssert(ctxt, t, net, dave,
net.Alice, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointDave)
txidHash, err = getChanPointFundingTxid(chanPointDave)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
daveChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
daveFundPoint := wire.OutPoint{
Hash: *daveChanTXID,
Index: chanPointDave.OutputIndex,
}
// Next, we'll create Carol and establish a channel to from her to
// Dave. Carol is started in htlchodl mode so that we can disconnect the
// intermediary hops before starting the settle.
carol, err := net.NewNode([]string{"--debughtlc", "--hodlhtlc"})
if err != nil {
t.Fatalf("unable to create new nodes: %v", err)
}
if err := net.ConnectNodes(ctxb, carol, dave); err != nil {
t.Fatalf("unable to connect carol to dave: %v", err)
}
err = net.SendCoins(ctxb, btcutil.SatoshiPerBitcoin, carol)
if err != nil {
t.Fatalf("unable to send coins to carol: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
chanPointCarol := openChannelAndAssert(ctxt, t, net, carol,
dave, chanAmt, pushAmt)
networkChans = append(networkChans, chanPointCarol)
txidHash, err = getChanPointFundingTxid(chanPointCarol)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
carolChanTXID, err := chainhash.NewHash(txidHash)
if err != nil {
t.Fatalf("unable to create sha hash: %v", err)
}
carolFundPoint := wire.OutPoint{
Hash: *carolChanTXID,
Index: chanPointCarol.OutputIndex,
}
// Wait for all nodes to have seen all channels.
nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave}
nodeNames := []string{"Alice", "Bob", "Carol", "Dave"}
for _, chanPoint := range networkChans {
for i, node := range nodes {
txidHash, err := getChanPointFundingTxid(chanPoint)
if err != nil {
t.Fatalf("unable to get txid: %v", err)
}
txid, e := chainhash.NewHash(txidHash)
if e != nil {
t.Fatalf("unable to create sha hash: %v", e)
}
point := wire.OutPoint{
Hash: *txid,
Index: chanPoint.OutputIndex,
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
if err != nil {
t.Fatalf("%s(%d): timeout waiting for "+
"channel(%s) open: %v", nodeNames[i],
node.NodeID, point, err)
}
}
}
// Create 5 invoices for Carol, which expect a payment from Bob for 1k
// satoshis with a different preimage each time.
const numPayments = 5
const paymentAmt = 1000
payReqs := make([]string, numPayments)
for i := 0; i < numPayments; i++ {
invoice := &lnrpc.Invoice{
Memo: "testing",
Value: paymentAmt,
}
resp, err := carol.AddInvoice(ctxb, invoice)
if err != nil {
t.Fatalf("unable to add invoice: %v", err)
}
payReqs[i] = resp.PaymentRequest
}
// We'll wait for all parties to recognize the new channels within the
// network.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = dave.WaitForNetworkChannelOpen(ctxt, chanPointDave)
if err != nil {
t.Fatalf("dave didn't advertise his channel: %v", err)
}
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = carol.WaitForNetworkChannelOpen(ctxt, chanPointCarol)
if err != nil {
t.Fatalf("carol didn't advertise her channel in time: %v",
err)
}
time.Sleep(time.Millisecond * 50)
// Using Carol as the source, pay to the 5 invoices from Bob created
// above.
ctxt, _ = context.WithTimeout(ctxb, timeout)
err = completePaymentRequests(ctxt, net.Bob, payReqs, false)
if err != nil {
t.Fatalf("unable to send payments: %v", err)
}
time.Sleep(2 * time.Second)
// Restart the intermediaries and the sender.
ctxt, _ = context.WithTimeout(ctxb, timeout)
if err := net.DisconnectNodes(ctxt, dave, net.Alice); err != nil {
t.Fatalf("unable to disconnect alice from dave: %v", err)
}
// Now restart carol without hodl mode, to settle back the outstanding
// payments.
carol.SetExtraArgs(nil)
if err := net.RestartNode(carol, nil); err != nil {
t.Fatalf("Node restart failed: %v", err)
}
time.Sleep(200 * time.Millisecond)
const amountPaid = int64(5000)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", carol,
carolFundPoint, int64(0), amountPaid)
assertAmountPaid(t, ctxb, "Dave(local) => Carol(remote)", dave,
carolFundPoint, amountPaid, int64(0))
if err := net.ShutdownNode(carol); err != nil {
t.Fatalf("unable to shutdown carol: %v", err)
}
if err := net.RestartNode(dave, nil); err != nil {
t.Fatalf("unable to reconnect alice to dave: %v", err)
}
time.Sleep(200 * time.Millisecond)
// When asserting the amount of satoshis moved, we'll factor in the
// default base fee, as we didn't modify the fee structure when
// creating the seed nodes in the network.
const baseFee = 1
// At this point all the channels within our proto network should be
// shifted by 5k satoshis in the direction of Carol, the sink within the
// payment flow generated above. The order of asserts corresponds to
// increasing of time is needed to embed the HTLC in commitment
// transaction, in channel Bob->Alice->David->Carol, order is Carol,
// David, Alice, Bob.
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", dave,
daveFundPoint, int64(0), amountPaid+(baseFee*numPayments))
assertAmountPaid(t, ctxb, "Alice(local) => Dave(remote)", net.Alice,
daveFundPoint, amountPaid+(baseFee*numPayments), int64(0))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Alice,
aliceFundPoint, int64(0), amountPaid+((baseFee*numPayments)*2))
assertAmountPaid(t, ctxb, "Bob(local) => Alice(remote)", net.Bob,
aliceFundPoint, amountPaid+(baseFee*numPayments)*2, int64(0))
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, net.Alice, chanPointAlice, false)
ctxt, _ = context.WithTimeout(ctxb, timeout)
closeChannelAndAssert(ctxt, t, net, dave, chanPointDave, false)
// Finally, shutdown the nodes we created for the duration of the tests,
// only leaving the two seed nodes (Alice and Bob) within our test
// network.
if err := net.ShutdownNode(dave); err != nil {
t.Fatalf("unable to shutdown dave: %v", err)
}
}
type testCase struct {
name string
test func(net *lntest.NetworkHarness, t *harnessTest)
@ -6490,6 +7775,10 @@ var testsCases = []*testCase{
name: "single hop invoice",
test: testSingleHopInvoice,
},
{
name: "sphinx replay persistence",
test: testSphinxReplayPersistence,
},
{
name: "list outgoing payments",
test: testListPayments,
@ -6587,6 +7876,22 @@ var testsCases = []*testCase{
name: "revoked uncooperative close retribution remote hodl",
test: testRevokedCloseRetributionRemoteHodl,
},
{
name: "switch circuit persistence",
test: testSwitchCircuitPersistence,
},
{
name: "switch offline delivery",
test: testSwitchOfflineDelivery,
},
{
name: "switch offline delivery persistence",
test: testSwitchOfflineDeliveryPersistence,
},
{
name: "switch offline delivery outgoing offline",
test: testSwitchOfflineDeliveryOutgoingOffline,
},
}
// TestLightningNetworkDaemon performs a series of integration tests amongst a

@ -239,6 +239,7 @@ func (hn *HarnessNode) start(lndError chan<- error) error {
hn.quit = make(chan struct{})
args := hn.cfg.genArgs()
args = append(args, fmt.Sprintf("--profile=%d", 9000+hn.NodeID))
hn.cmd = exec.Command("lnd", args...)
// Redirect stderr output to buffer
@ -394,6 +395,12 @@ func (hn *HarnessNode) connectRPC() (*grpc.ClientConn, error) {
return grpc.Dial(hn.cfg.RPCAddr(), opts...)
}
// SetExtraArgs assigns the ExtraArgs field for the node's configuration. The
// changes will take effect on restart.
func (hn *HarnessNode) SetExtraArgs(extraArgs []string) {
hn.cfg.ExtraArgs = extraArgs
}
// cleanup cleans up all the temporary files created by the node's process.
func (hn *HarnessNode) cleanup() error {
return os.RemoveAll(hn.cfg.BaseDir)

53
peer.go

@ -379,18 +379,21 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
linkCfg := htlcswitch.ChannelLinkConfig{
Peer: p,
DecodeHopIterator: p.server.sphinx.DecodeHopIterator,
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
DecodeOnionObfuscator: p.server.sphinx.ExtractErrorEncrypter,
GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter,
p.PubKey(), lnChan.ShortChanID()),
DebugHTLC: cfg.DebugHTLC,
HodlHTLC: cfg.HodlHTLC,
Registry: p.server.invoices,
Switch: p.server.htlcSwitch,
FwrdingPolicy: *forwardingPolicy,
FeeEstimator: p.server.cc.feeEstimator,
BlockEpochs: blockEpoch,
PreimageCache: p.server.witnessBeacon,
ChainEvents: chainEvents,
DebugHTLC: cfg.DebugHTLC,
HodlHTLC: cfg.HodlHTLC,
Registry: p.server.invoices,
Switch: p.server.htlcSwitch,
Circuits: p.server.htlcSwitch.CircuitModifier(),
ForwardPackets: p.server.htlcSwitch.ForwardPackets,
FwrdingPolicy: *forwardingPolicy,
FeeEstimator: p.server.cc.feeEstimator,
BlockEpochs: blockEpoch,
PreimageCache: p.server.witnessBeacon,
ChainEvents: chainEvents,
UpdateContractSignals: func(signals *contractcourt.ContractSignals) error {
return p.server.chainArb.UpdateContractSignals(
*chanPoint, signals,
@ -399,7 +402,10 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
SyncStates: true,
BatchTicker: htlcswitch.NewBatchTicker(
time.NewTicker(50 * time.Millisecond)),
BatchSize: 10,
FwdPkgGCTicker: htlcswitch.NewBatchTicker(
time.NewTicker(time.Minute)),
BatchSize: 10,
UnsafeReplay: cfg.UnsafeReplay,
}
link := htlcswitch.NewChannelLink(linkCfg, lnChan,
uint32(currentHeight))
@ -1020,6 +1026,7 @@ func (p *peer) writeMessage(msg lnwire.Message) error {
// NOTE: This method MUST be run as a goroutine.
func (p *peer) writeHandler() {
var exitErr error
out:
for {
select {
@ -1272,18 +1279,21 @@ out:
linkConfig := htlcswitch.ChannelLinkConfig{
Peer: p,
DecodeHopIterator: p.server.sphinx.DecodeHopIterator,
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
DecodeOnionObfuscator: p.server.sphinx.ExtractErrorEncrypter,
GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter,
p.PubKey(), newChanReq.channel.ShortChanID()),
DebugHTLC: cfg.DebugHTLC,
HodlHTLC: cfg.HodlHTLC,
Registry: p.server.invoices,
Switch: p.server.htlcSwitch,
FwrdingPolicy: p.server.cc.routingPolicy,
FeeEstimator: p.server.cc.feeEstimator,
BlockEpochs: blockEpoch,
PreimageCache: p.server.witnessBeacon,
ChainEvents: chainEvents,
DebugHTLC: cfg.DebugHTLC,
HodlHTLC: cfg.HodlHTLC,
Registry: p.server.invoices,
Switch: p.server.htlcSwitch,
Circuits: p.server.htlcSwitch.CircuitModifier(),
ForwardPackets: p.server.htlcSwitch.ForwardPackets,
FwrdingPolicy: p.server.cc.routingPolicy,
FeeEstimator: p.server.cc.feeEstimator,
BlockEpochs: blockEpoch,
PreimageCache: p.server.witnessBeacon,
ChainEvents: chainEvents,
UpdateContractSignals: func(signals *contractcourt.ContractSignals) error {
return p.server.chainArb.UpdateContractSignals(
*chanPoint, signals,
@ -1292,7 +1302,10 @@ out:
SyncStates: false,
BatchTicker: htlcswitch.NewBatchTicker(
time.NewTicker(50 * time.Millisecond)),
BatchSize: 10,
FwdPkgGCTicker: htlcswitch.NewBatchTicker(
time.NewTicker(time.Minute)),
BatchSize: 10,
UnsafeReplay: cfg.UnsafeReplay,
}
link := htlcswitch.NewChannelLink(linkConfig, newChan,
uint32(currentHeight))

@ -650,7 +650,7 @@ func (r *rpcServer) DisconnectPeer(ctx context.Context,
// In order to avoid erroneously disconnecting from a peer that we have
// an active channel with, if we have any channels active with this
// peer, then we'll disallow disconnecting from them.
if len(nodeChannels) > 0 {
if len(nodeChannels) > 0 && !cfg.UnsafeDisconnect {
return nil, fmt.Errorf("cannot disconnect from peer(%x), "+
"all active channels with the peer need to be closed "+
"first", pubKeyBytes)

@ -206,7 +206,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
debugPre[:], debugHash[:])
}
s.htlcSwitch = htlcswitch.New(htlcswitch.Config{
htlcSwitch, err := htlcswitch.New(htlcswitch.Config{
DB: chanDB,
SelfKey: s.identityPriv.PubKey(),
LocalChannelClose: func(pubKey []byte,
request *htlcswitch.ChanClose) {
@ -230,8 +231,13 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
pubKey[:], err)
}
},
FwdingLog: chanDB.ForwardingLog(),
FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(),
})
if err != nil {
return nil, err
}
s.htlcSwitch = htlcSwitch
// If external IP addresses have been specified, add those to the list
// of this server's addresses. We need to use the cfg.net.ResolveTCPAddr

@ -2,6 +2,9 @@ package main
import (
"bytes"
crand "crypto/rand"
"encoding/binary"
"io"
"io/ioutil"
"math/rand"
"net"
@ -191,11 +194,21 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
CommitSig: bytes.Repeat([]byte{1}, 71),
}
var chanIDBytes [8]byte
if _, err := io.ReadFull(crand.Reader, chanIDBytes[:]); err != nil {
return nil, nil, nil, nil, err
}
shortChanID := lnwire.NewShortChanIDFromInt(
binary.BigEndian.Uint64(chanIDBytes[:]),
)
aliceChannelState := &channeldb.OpenChannel{
LocalChanCfg: aliceCfg,
RemoteChanCfg: bobCfg,
IdentityPub: aliceKeyPub,
FundingOutpoint: *prevOut,
ShortChanID: shortChanID,
ChanType: channeldb.SingleFunder,
IsInitiator: true,
Capacity: channelCapacity,
@ -205,6 +218,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
LocalCommitment: aliceCommit,
RemoteCommitment: aliceCommit,
Db: dbAlice,
Packager: channeldb.NewChannelPackager(shortChanID),
}
bobChannelState := &channeldb.OpenChannel{
LocalChanCfg: bobCfg,
@ -220,6 +234,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
LocalCommitment: bobCommit,
RemoteCommitment: bobCommit,
Db: dbBob,
Packager: channeldb.NewChannelPackager(shortChanID),
}
addr := &net.TCPAddr{
@ -291,7 +306,14 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
breachArbiter: breachArbiter,
chainArb: chainArb,
}
s.htlcSwitch = htlcswitch.New(htlcswitch.Config{})
htlcSwitch, err := htlcswitch.New(htlcswitch.Config{
DB: dbAlice,
SwitchPackager: channeldb.NewSwitchPackager(),
})
if err != nil {
return nil, nil, nil, nil, err
}
s.htlcSwitch = htlcSwitch
s.htlcSwitch.Start()
alicePeer := &peer{