htlcswitch+server: ensure we always send an update w/ a TempChannelFailure

In this commit, we ensure that any time we send a TempChannelFailure
that's destined for a multi-hop source sender, then we'll always package
the latest channel update along with it.
This commit is contained in:
Olaoluwa Osuntokun 2018-05-07 20:00:32 -07:00 committed by Wilmer Paulino
parent 27ca61aedf
commit 72f48b6abe
No known key found for this signature in database
GPG Key ID: 6DF57B9F9514972F
6 changed files with 101 additions and 40 deletions

View File

@ -982,10 +982,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
reason lnwire.OpaqueReason
)
failure := lnwire.NewTemporaryChannelFailure(nil)
var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.ShortChanID(),
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(
update,
)
}
// Encrypt the error back to the source unless the payment was
// generated locally.
// Encrypt the error back to the source unless
// the payment was generated locally.
if pkt.obfuscator == nil {
var b bytes.Buffer
err := lnwire.EncodeFailure(&b, failure, 0)
@ -1652,11 +1662,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
// As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data.
var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.shortChanID,
)
update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
if err != nil {
failure = lnwire.NewTemporaryChannelFailure(nil)
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewAmountBelowMinimum(
amtToForward, *update,
@ -1686,11 +1694,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
// As part of the returned error, we'll send our latest routing
// policy so the sending node obtains the most up to date data.
var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.shortChanID,
)
update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
if err != nil {
failure = lnwire.NewTemporaryChannelFailure(nil)
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewFeeInsufficient(
amtToForward, *update,
@ -2242,10 +2248,12 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.shortChanID,
l.ShortChanID(),
)
if err != nil {
failure = lnwire.NewTemporaryChannelFailure(nil)
failure = lnwire.NewTemporaryChannelFailure(
update,
)
} else {
failure = lnwire.NewExpiryTooSoon(*update)
}
@ -2275,7 +2283,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// sending node is up to date with our current
// policy.
update, err := l.cfg.FetchLastChannelUpdate(
l.shortChanID,
l.ShortChanID(),
)
if err != nil {
l.fail("unable to create channel update "+
@ -2313,7 +2321,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
log.Errorf("unable to encode the "+
"remaining route %v", err)
failure := lnwire.NewTemporaryChannelFailure(nil)
var failure lnwire.FailureMessage
update, err := l.cfg.FetchLastChannelUpdate(
l.ShortChanID(),
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(
update,
)
}
l.sendHTLCError(
pd.HtlcIndex, failure, obfuscator, pd.SourceRef,

View File

@ -1459,8 +1459,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
}
aliceDb := aliceChannel.State().Db
aliceSwitch, err := New(Config{DB: aliceDb})
aliceSwitch, err := initSwitchWithDB(aliceDb)
if err != nil {
return nil, nil, nil, nil, nil, err
}
@ -3854,7 +3853,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
if aliceSwitch == nil {
var err error
aliceSwitch, err = New(Config{DB: aliceDb})
aliceSwitch, err = initSwitchWithDB(aliceDb)
if err != nil {
return nil, nil, nil, err
}

View File

@ -140,6 +140,9 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return nil, nil
},
})
}

View File

@ -135,6 +135,13 @@ type Config struct {
// error encrypters stored in the circuit map on restarts, since they
// are not stored directly within the database.
ExtractErrorEncrypter ErrorEncrypterExtracter
// FetchLastChannelUpdate retrieves the latest routing policy for a
// target channel. This channel will typically be the outgoing channel
// specified when we receive an incoming HTLC. This will be used to
// provide payment senders our latest policy when sending encrypted
// error messages.
FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error)
}
// Switch is the central messaging bus for all incoming/outgoing HTLCs.
@ -458,7 +465,15 @@ func (s *Switch) forward(packet *htlcPacket) error {
return err
}
failure := lnwire.NewTemporaryChannelFailure(nil)
var failure lnwire.FailureMessage
update, err := s.cfg.FetchLastChannelUpdate(
packet.incomingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
addErr := ErrIncompleteForward
return s.failAddPacket(packet, failure, addErr)
@ -588,14 +603,25 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
// Lastly, for any packets that failed, this implies that they were
// left in a half added state, which can happen when recovering from
// failures.
for _, packet := range failedPackets {
failure := lnwire.NewTemporaryChannelFailure(nil)
addErr := errors.Errorf("failing packet after detecting " +
"incomplete forward")
if len(failedPackets) > 0 {
var failure lnwire.FailureMessage
update, err := s.cfg.FetchLastChannelUpdate(
failedPackets[0].incomingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
// We don't handle the error here since this method always
// returns an error.
s.failAddPacket(packet, failure, addErr)
for _, packet := range failedPackets {
addErr := errors.Errorf("failing packet after " +
"detecting incomplete forward")
// We don't handle the error here since this method
// always returns an error.
s.failAddPacket(packet, failure, addErr)
}
}
return errChan
@ -749,6 +775,8 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
htlc.Amount, largestBandwidth)
log.Error(err)
// Note that we don't need to populate an update here,
// as this will go directly back to the router.
htlcErr := lnwire.NewTemporaryChannelFailure(nil)
return &ForwardingError{
ErrorSource: s.cfg.SelfKey,
@ -812,6 +840,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
userErr = fmt.Sprintf("unable to decode onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err)
log.Error(userErr)
// As this didn't even clear the link, we don't need to
// apply an update here since it goes directly to the
// router.
failureMsg = lnwire.NewTemporaryChannelFailure(nil)
}
failure = &ForwardingError{
@ -938,7 +970,16 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// If packet was forwarded from another channel link
// than we should notify this link that some error
// occurred.
failure := lnwire.NewTemporaryChannelFailure(nil)
var failure lnwire.FailureMessage
update, err := s.cfg.FetchLastChannelUpdate(
packet.outgoingChanID,
)
if err != nil {
failure = &lnwire.FailTemporaryNodeFailure{}
} else {
failure = lnwire.NewTemporaryChannelFailure(update)
}
addErr := errors.Errorf("unable to find appropriate "+
"channel link insufficient capacity, need "+
"%v", htlc.Amount)
@ -1799,7 +1840,7 @@ func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
s.indexMtx.Lock()
// First, we'll extract the current link as is from the link
// First, we'll extract the current link as is from the link
// index. If the link isn't even in the index, then we'll return an
// error.
link, ok := s.linkIndex[chanID]

View File

@ -21,7 +21,6 @@ import (
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing"
"github.com/roasbeef/btcd/chaincfg/chainhash"
"github.com/roasbeef/btcd/connmgr"
"github.com/roasbeef/btcd/txscript"
@ -414,7 +413,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate(
p.server.chanRouter, p.PubKey(),
p.server, p.PubKey(),
),
DebugHTLC: cfg.DebugHTLC,
HodlMask: cfg.Hodl.Mask(),
@ -1392,7 +1391,7 @@ out:
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate(
p.server.chanRouter, p.PubKey(),
p.server, p.PubKey(),
),
DebugHTLC: cfg.DebugHTLC,
HodlMask: cfg.Hodl.Mask(),
@ -1911,11 +1910,11 @@ func (p *peer) PubKey() [33]byte {
// fetchLastChanUpdate returns a function which is able to retrieve the last
// channel update for a target channel.
func fetchLastChanUpdate(router *routing.ChannelRouter,
func fetchLastChanUpdate(s *server,
pubKey [33]byte) func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
info, edge1, edge2, err := router.GetChannelByID(cid)
info, edge1, edge2, err := s.chanRouter.GetChannelByID(cid)
if err != nil {
return nil, err
}

View File

@ -153,7 +153,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
globalFeatures := lnwire.NewRawFeatureVector()
serializedPubKey := privKey.PubKey().SerializeCompressed()
var serializedPubKey [33]byte
copy(serializedPubKey[:], privKey.PubKey().SerializeCompressed())
// Initialize the sphinx router, placing it's persistent replay log in
// the same directory as the channel graph database.
@ -175,7 +176,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
// TODO(roasbeef): derive proper onion key based on rotation
// schedule
sphinx: htlcswitch.NewOnionProcessor(sphinxRouter),
lightningID: sha256.Sum256(serializedPubKey),
lightningID: sha256.Sum256(serializedPubKey[:]),
persistentPeers: make(map[string]struct{}),
persistentPeersBackoff: make(map[string]time.Duration),
@ -209,7 +210,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
debugPre[:], debugHash[:])
}
htlcSwitch, err := htlcswitch.New(htlcswitch.Config{
s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{
DB: chanDB,
SelfKey: s.identityPriv.PubKey(),
LocalChannelClose: func(pubKey []byte,
@ -234,14 +235,14 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
pubKey[:], err)
}
},
FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(),
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(),
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey),
})
if err != nil {
return nil, err
}
s.htlcSwitch = htlcSwitch
// If external IP addresses have been specified, add those to the list
// of this server's addresses. We need to use the cfg.net.ResolveTCPAddr