htlcswitch+server: ensure we always send an update w/ a TempChannelFailure

In this commit, we ensure that any time we send a TempChannelFailure
that's destined for a multi-hop source sender, then we'll always package
the latest channel update along with it.
This commit is contained in:
Olaoluwa Osuntokun 2018-05-07 20:00:32 -07:00 committed by Wilmer Paulino
parent 27ca61aedf
commit 72f48b6abe
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F
6 changed files with 101 additions and 40 deletions

View File

@ -982,10 +982,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
reason lnwire.OpaqueReason reason lnwire.OpaqueReason
) )
failure := lnwire.NewTemporaryChannelFailure(nil) var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.ShortChanID(),
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(
update,
)
}
// Encrypt the error back to the source unless the payment was // Encrypt the error back to the source unless
// generated locally. // the payment was generated locally.
if pkt.obfuscator == nil { if pkt.obfuscator == nil {
var b bytes.Buffer var b bytes.Buffer
err := lnwire.EncodeFailure(&b, failure, 0) err := lnwire.EncodeFailure(&b, failure, 0)
@ -1652,11 +1662,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
// As part of the returned error, we'll send our latest routing // As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data. // policy so the sending node obtains the most up to date data.
var failure lnwire.FailureMessage var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate( update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
l.shortChanID,
)
if err != nil { if err != nil {
failure = lnwire.NewTemporaryChannelFailure(nil) failure = &lnwire.FailTemporaryNodeFailure{}
} else { } else {
failure = lnwire.NewAmountBelowMinimum( failure = lnwire.NewAmountBelowMinimum(
amtToForward, *update, amtToForward, *update,
@ -1686,11 +1694,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
// As part of the returned error, we'll send our latest routing // As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data. // policy so the sending node obtains the most up to date data.
var failure lnwire.FailureMessage var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate( update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
l.shortChanID,
)
if err != nil { if err != nil {
failure = lnwire.NewTemporaryChannelFailure(nil) failure = &lnwire.FailTemporaryNodeFailure{}
} else { } else {
failure = lnwire.NewFeeInsufficient( failure = lnwire.NewFeeInsufficient(
amtToForward, *update, amtToForward, *update,
@ -2242,10 +2248,12 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
var failure lnwire.FailureMessage var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate( update, err := l.cfg.FetchLastChannelUpdate(
l.shortChanID, l.ShortChanID(),
) )
if err != nil { if err != nil {
failure = lnwire.NewTemporaryChannelFailure(nil) failure = lnwire.NewTemporaryChannelFailure(
update,
)
} else { } else {
failure = lnwire.NewExpiryTooSoon(*update) failure = lnwire.NewExpiryTooSoon(*update)
} }
@ -2275,7 +2283,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// sending node is up to date with our current // sending node is up to date with our current
// policy. // policy.
update, err := l.cfg.FetchLastChannelUpdate( update, err := l.cfg.FetchLastChannelUpdate(
l.shortChanID, l.ShortChanID(),
) )
if err != nil { if err != nil {
l.fail("unable to create channel update "+ l.fail("unable to create channel update "+
@ -2313,7 +2321,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
log.Errorf("unable to encode the "+ log.Errorf("unable to encode the "+
"remaining route %v", err) "remaining route %v", err)
failure := lnwire.NewTemporaryChannelFailure(nil) var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.ShortChanID(),
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(
update,
)
}
l.sendHTLCError( l.sendHTLCError(
pd.HtlcIndex, failure, obfuscator, pd.SourceRef, pd.HtlcIndex, failure, obfuscator, pd.SourceRef,

View File

@ -1459,8 +1459,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
} }
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db
aliceSwitch, err := initSwitchWithDB(aliceDb)
aliceSwitch, err := New(Config{DB: aliceDb})
if err != nil { if err != nil {
return nil, nil, nil, nil, nil, err return nil, nil, nil, nil, nil, err
} }
@ -3854,7 +3853,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
if aliceSwitch == nil { if aliceSwitch == nil {
var err error var err error
aliceSwitch, err = New(Config{DB: aliceDb}) aliceSwitch, err = initSwitchWithDB(aliceDb)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }

View File

@ -140,6 +140,9 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
FwdingLog: &mockForwardingLog{ FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent), events: make(map[time.Time]channeldb.ForwardingEvent),
}, },
FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, nil
},
}) })
} }

View File

@ -135,6 +135,13 @@ type Config struct {
// error encrypters stored in the circuit map on restarts, since they // error encrypters stored in the circuit map on restarts, since they
// are not stored directly within the database. // are not stored directly within the database.
ExtractErrorEncrypter ErrorEncrypterExtracter ExtractErrorEncrypter ErrorEncrypterExtracter
// FetchLastChannelUpdate retrieves the latest routing policy for a
// target channel. This channel will typically be the outgoing channel
// specified when we receive an incoming HTLC. This will be used to
// provide payment senders our latest policy when sending encrypted
// error messages.
FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error)
} }
// Switch is the central messaging bus for all incoming/outgoing HTLCs. // Switch is the central messaging bus for all incoming/outgoing HTLCs.
@ -458,7 +465,15 @@ func (s *Switch) forward(packet *htlcPacket) error {
return err return err
} }
failure := lnwire.NewTemporaryChannelFailure(nil) var failure lnwire.FailureMessage
update, err := s.cfg.FetchLastChannelUpdate(
packet.incomingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
addErr := ErrIncompleteForward addErr := ErrIncompleteForward
return s.failAddPacket(packet, failure, addErr) return s.failAddPacket(packet, failure, addErr)
@ -588,14 +603,25 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
// Lastly, for any packets that failed, this implies that they were // Lastly, for any packets that failed, this implies that they were
// left in a half added state, which can happen when recovering from // left in a half added state, which can happen when recovering from
// failures. // failures.
for _, packet := range failedPackets { if len(failedPackets) > 0 {
failure := lnwire.NewTemporaryChannelFailure(nil) var failure lnwire.FailureMessage
addErr := errors.Errorf("failing packet after detecting " + update, err := s.cfg.FetchLastChannelUpdate(
"incomplete forward") failedPackets[0].incomingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
// We don't handle the error here since this method always for _, packet := range failedPackets {
// returns an error. addErr := errors.Errorf("failing packet after " +
s.failAddPacket(packet, failure, addErr) "detecting incomplete forward")
// We don't handle the error here since this method
// always returns an error.
s.failAddPacket(packet, failure, addErr)
}
} }
return errChan return errChan
@ -749,6 +775,8 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
htlc.Amount, largestBandwidth) htlc.Amount, largestBandwidth)
log.Error(err) log.Error(err)
// Note that we don't need to populate an update here,
// as this will go directly back to the router.
htlcErr := lnwire.NewTemporaryChannelFailure(nil) htlcErr := lnwire.NewTemporaryChannelFailure(nil)
return &ForwardingError{ return &ForwardingError{
ErrorSource: s.cfg.SelfKey, ErrorSource: s.cfg.SelfKey,
@ -812,6 +840,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
userErr = fmt.Sprintf("unable to decode onion failure, "+ userErr = fmt.Sprintf("unable to decode onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err) "htlc with hash(%x): %v", payment.paymentHash[:], err)
log.Error(userErr) log.Error(userErr)
// As this didn't even clear the link, we don't need to
// apply an update here since it goes directly to the
// router.
failureMsg = lnwire.NewTemporaryChannelFailure(nil) failureMsg = lnwire.NewTemporaryChannelFailure(nil)
} }
failure = &ForwardingError{ failure = &ForwardingError{
@ -938,7 +970,16 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// If packet was forwarded from another channel link // If packet was forwarded from another channel link
// than we should notify this link that some error // than we should notify this link that some error
// occurred. // occurred.
failure := lnwire.NewTemporaryChannelFailure(nil) var failure lnwire.FailureMessage
update, err := s.cfg.FetchLastChannelUpdate(
packet.outgoingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
addErr := errors.Errorf("unable to find appropriate "+ addErr := errors.Errorf("unable to find appropriate "+
"channel link insufficient capacity, need "+ "channel link insufficient capacity, need "+
"%v", htlc.Amount) "%v", htlc.Amount)
@ -1799,7 +1840,7 @@ func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
s.indexMtx.Lock() s.indexMtx.Lock()
// First, we'll extract the current link as is from the link // First, we'll extract the current link as is from the link
// index. If the link isn't even in the index, then we'll return an // index. If the link isn't even in the index, then we'll return an
// error. // error.
link, ok := s.linkIndex[chanID] link, ok := s.linkIndex[chanID]

View File

@ -21,7 +21,6 @@ import (
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing"
"github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcd/chaincfg/chainhash"
"github.com/roasbeef/btcd/connmgr" "github.com/roasbeef/btcd/connmgr"
"github.com/roasbeef/btcd/txscript" "github.com/roasbeef/btcd/txscript"
@ -414,7 +413,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
DecodeHopIterators: p.server.sphinx.DecodeHopIterators, DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter, ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate( FetchLastChannelUpdate: fetchLastChanUpdate(
p.server.chanRouter, p.PubKey(), p.server, p.PubKey(),
), ),
DebugHTLC: cfg.DebugHTLC, DebugHTLC: cfg.DebugHTLC,
HodlMask: cfg.Hodl.Mask(), HodlMask: cfg.Hodl.Mask(),
@ -1392,7 +1391,7 @@ out:
DecodeHopIterators: p.server.sphinx.DecodeHopIterators, DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter, ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate( FetchLastChannelUpdate: fetchLastChanUpdate(
p.server.chanRouter, p.PubKey(), p.server, p.PubKey(),
), ),
DebugHTLC: cfg.DebugHTLC, DebugHTLC: cfg.DebugHTLC,
HodlMask: cfg.Hodl.Mask(), HodlMask: cfg.Hodl.Mask(),
@ -1911,11 +1910,11 @@ func (p *peer) PubKey() [33]byte {
// fetchLastChanUpdate returns a function which is able to retrieve the last // fetchLastChanUpdate returns a function which is able to retrieve the last
// channel update for a target channel. // channel update for a target channel.
func fetchLastChanUpdate(router *routing.ChannelRouter, func fetchLastChanUpdate(s *server,
pubKey [33]byte) func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { pubKey [33]byte) func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
info, edge1, edge2, err := router.GetChannelByID(cid) info, edge1, edge2, err := s.chanRouter.GetChannelByID(cid)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -153,7 +153,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
globalFeatures := lnwire.NewRawFeatureVector() globalFeatures := lnwire.NewRawFeatureVector()
serializedPubKey := privKey.PubKey().SerializeCompressed() var serializedPubKey [33]byte
copy(serializedPubKey[:], privKey.PubKey().SerializeCompressed())
// Initialize the sphinx router, placing it's persistent replay log in // Initialize the sphinx router, placing it's persistent replay log in
// the same directory as the channel graph database. // the same directory as the channel graph database.
@ -175,7 +176,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
// TODO(roasbeef): derive proper onion key based on rotation // TODO(roasbeef): derive proper onion key based on rotation
// schedule // schedule
sphinx: htlcswitch.NewOnionProcessor(sphinxRouter), sphinx: htlcswitch.NewOnionProcessor(sphinxRouter),
lightningID: sha256.Sum256(serializedPubKey), lightningID: sha256.Sum256(serializedPubKey[:]),
persistentPeers: make(map[string]struct{}), persistentPeers: make(map[string]struct{}),
persistentPeersBackoff: make(map[string]time.Duration), persistentPeersBackoff: make(map[string]time.Duration),
@ -209,7 +210,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
debugPre[:], debugHash[:]) debugPre[:], debugHash[:])
} }
htlcSwitch, err := htlcswitch.New(htlcswitch.Config{ s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{
DB: chanDB, DB: chanDB,
SelfKey: s.identityPriv.PubKey(), SelfKey: s.identityPriv.PubKey(),
LocalChannelClose: func(pubKey []byte, LocalChannelClose: func(pubKey []byte,
@ -234,14 +235,14 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
pubKey[:], err) pubKey[:], err)
} }
}, },
FwdingLog: chanDB.ForwardingLog(), FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(), SwitchPackager: channeldb.NewSwitchPackager(),
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter, ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.htlcSwitch = htlcSwitch
// If external IP addresses have been specified, add those to the list // If external IP addresses have been specified, add those to the list
// of this server's addresses. We need to use the cfg.net.ResolveTCPAddr // of this server's addresses. We need to use the cfg.net.ResolveTCPAddr