htlcswitch+server: ensure we always send an update w/ a TempChannelFailure
In this commit, we ensure that any time we send a TempChannelFailure that's destined for a multi-hop source sender, then we'll always package the latest channel update along with it.
This commit is contained in:
parent
27ca61aedf
commit
72f48b6abe
@ -982,10 +982,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
|
||||
reason lnwire.OpaqueReason
|
||||
)
|
||||
|
||||
failure := lnwire.NewTemporaryChannelFailure(nil)
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := l.cfg.FetchLastChannelUpdate(
|
||||
l.ShortChanID(),
|
||||
)
|
||||
if err != nil {
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewTemporaryChannelFailure(
|
||||
update,
|
||||
)
|
||||
}
|
||||
|
||||
// Encrypt the error back to the source unless the payment was
|
||||
// generated locally.
|
||||
// Encrypt the error back to the source unless
|
||||
// the payment was generated locally.
|
||||
if pkt.obfuscator == nil {
|
||||
var b bytes.Buffer
|
||||
err := lnwire.EncodeFailure(&b, failure, 0)
|
||||
@ -1652,11 +1662,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
|
||||
// As part of the returned error, we'll send our latest routing
|
||||
// policy so the sending node obtains the most up to date data.
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := l.cfg.FetchLastChannelUpdate(
|
||||
l.shortChanID,
|
||||
)
|
||||
update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
|
||||
if err != nil {
|
||||
failure = lnwire.NewTemporaryChannelFailure(nil)
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewAmountBelowMinimum(
|
||||
amtToForward, *update,
|
||||
@ -1686,11 +1694,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte,
|
||||
// As part of the returned error, we'll send our latest routing
|
||||
// policy so the sending node obtains the most up to date data.
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := l.cfg.FetchLastChannelUpdate(
|
||||
l.shortChanID,
|
||||
)
|
||||
update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID())
|
||||
if err != nil {
|
||||
failure = lnwire.NewTemporaryChannelFailure(nil)
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewFeeInsufficient(
|
||||
amtToForward, *update,
|
||||
@ -2242,10 +2248,12 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := l.cfg.FetchLastChannelUpdate(
|
||||
l.shortChanID,
|
||||
l.ShortChanID(),
|
||||
)
|
||||
if err != nil {
|
||||
failure = lnwire.NewTemporaryChannelFailure(nil)
|
||||
failure = lnwire.NewTemporaryChannelFailure(
|
||||
update,
|
||||
)
|
||||
} else {
|
||||
failure = lnwire.NewExpiryTooSoon(*update)
|
||||
}
|
||||
@ -2275,7 +2283,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
// sending node is up to date with our current
|
||||
// policy.
|
||||
update, err := l.cfg.FetchLastChannelUpdate(
|
||||
l.shortChanID,
|
||||
l.ShortChanID(),
|
||||
)
|
||||
if err != nil {
|
||||
l.fail("unable to create channel update "+
|
||||
@ -2313,7 +2321,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
|
||||
log.Errorf("unable to encode the "+
|
||||
"remaining route %v", err)
|
||||
|
||||
failure := lnwire.NewTemporaryChannelFailure(nil)
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := l.cfg.FetchLastChannelUpdate(
|
||||
l.ShortChanID(),
|
||||
)
|
||||
if err != nil {
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewTemporaryChannelFailure(
|
||||
update,
|
||||
)
|
||||
}
|
||||
|
||||
l.sendHTLCError(
|
||||
pd.HtlcIndex, failure, obfuscator, pd.SourceRef,
|
||||
|
@ -1459,8 +1459,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
|
||||
}
|
||||
|
||||
aliceDb := aliceChannel.State().Db
|
||||
|
||||
aliceSwitch, err := New(Config{DB: aliceDb})
|
||||
aliceSwitch, err := initSwitchWithDB(aliceDb)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, err
|
||||
}
|
||||
@ -3854,7 +3853,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
|
||||
|
||||
if aliceSwitch == nil {
|
||||
var err error
|
||||
aliceSwitch, err = New(Config{DB: aliceDb})
|
||||
aliceSwitch, err = initSwitchWithDB(aliceDb)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
@ -140,6 +140,9 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
|
||||
FwdingLog: &mockForwardingLog{
|
||||
events: make(map[time.Time]channeldb.ForwardingEvent),
|
||||
},
|
||||
FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -135,6 +135,13 @@ type Config struct {
|
||||
// error encrypters stored in the circuit map on restarts, since they
|
||||
// are not stored directly within the database.
|
||||
ExtractErrorEncrypter ErrorEncrypterExtracter
|
||||
|
||||
// FetchLastChannelUpdate retrieves the latest routing policy for a
|
||||
// target channel. This channel will typically be the outgoing channel
|
||||
// specified when we receive an incoming HTLC. This will be used to
|
||||
// provide payment senders our latest policy when sending encrypted
|
||||
// error messages.
|
||||
FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error)
|
||||
}
|
||||
|
||||
// Switch is the central messaging bus for all incoming/outgoing HTLCs.
|
||||
@ -458,7 +465,15 @@ func (s *Switch) forward(packet *htlcPacket) error {
|
||||
return err
|
||||
}
|
||||
|
||||
failure := lnwire.NewTemporaryChannelFailure(nil)
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := s.cfg.FetchLastChannelUpdate(
|
||||
packet.incomingChanID,
|
||||
)
|
||||
if err != nil {
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewTemporaryChannelFailure(update)
|
||||
}
|
||||
addErr := ErrIncompleteForward
|
||||
|
||||
return s.failAddPacket(packet, failure, addErr)
|
||||
@ -588,14 +603,25 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error {
|
||||
// Lastly, for any packets that failed, this implies that they were
|
||||
// left in a half added state, which can happen when recovering from
|
||||
// failures.
|
||||
for _, packet := range failedPackets {
|
||||
failure := lnwire.NewTemporaryChannelFailure(nil)
|
||||
addErr := errors.Errorf("failing packet after detecting " +
|
||||
"incomplete forward")
|
||||
if len(failedPackets) > 0 {
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := s.cfg.FetchLastChannelUpdate(
|
||||
failedPackets[0].incomingChanID,
|
||||
)
|
||||
if err != nil {
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewTemporaryChannelFailure(update)
|
||||
}
|
||||
|
||||
// We don't handle the error here since this method always
|
||||
// returns an error.
|
||||
s.failAddPacket(packet, failure, addErr)
|
||||
for _, packet := range failedPackets {
|
||||
addErr := errors.Errorf("failing packet after " +
|
||||
"detecting incomplete forward")
|
||||
|
||||
// We don't handle the error here since this method
|
||||
// always returns an error.
|
||||
s.failAddPacket(packet, failure, addErr)
|
||||
}
|
||||
}
|
||||
|
||||
return errChan
|
||||
@ -749,6 +775,8 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
|
||||
htlc.Amount, largestBandwidth)
|
||||
log.Error(err)
|
||||
|
||||
// Note that we don't need to populate an update here,
|
||||
// as this will go directly back to the router.
|
||||
htlcErr := lnwire.NewTemporaryChannelFailure(nil)
|
||||
return &ForwardingError{
|
||||
ErrorSource: s.cfg.SelfKey,
|
||||
@ -812,6 +840,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
|
||||
userErr = fmt.Sprintf("unable to decode onion failure, "+
|
||||
"htlc with hash(%x): %v", payment.paymentHash[:], err)
|
||||
log.Error(userErr)
|
||||
|
||||
// As this didn't even clear the link, we don't need to
|
||||
// apply an update here since it goes directly to the
|
||||
// router.
|
||||
failureMsg = lnwire.NewTemporaryChannelFailure(nil)
|
||||
}
|
||||
failure = &ForwardingError{
|
||||
@ -938,7 +970,16 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||
// If packet was forwarded from another channel link
|
||||
// than we should notify this link that some error
|
||||
// occurred.
|
||||
failure := lnwire.NewTemporaryChannelFailure(nil)
|
||||
var failure lnwire.FailureMessage
|
||||
update, err := s.cfg.FetchLastChannelUpdate(
|
||||
packet.outgoingChanID,
|
||||
)
|
||||
if err != nil {
|
||||
failure = &lnwire.FailTemporaryNodeFailure{}
|
||||
} else {
|
||||
failure = lnwire.NewTemporaryChannelFailure(update)
|
||||
}
|
||||
|
||||
addErr := errors.Errorf("unable to find appropriate "+
|
||||
"channel link insufficient capacity, need "+
|
||||
"%v", htlc.Amount)
|
||||
@ -1799,7 +1840,7 @@ func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID,
|
||||
|
||||
s.indexMtx.Lock()
|
||||
|
||||
// First, we'll extract the current link as is from the link
|
||||
// First, we'll extract the current link as is from the link
|
||||
// index. If the link isn't even in the index, then we'll return an
|
||||
// error.
|
||||
link, ok := s.linkIndex[chanID]
|
||||
|
9
peer.go
9
peer.go
@ -21,7 +21,6 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing"
|
||||
"github.com/roasbeef/btcd/chaincfg/chainhash"
|
||||
"github.com/roasbeef/btcd/connmgr"
|
||||
"github.com/roasbeef/btcd/txscript"
|
||||
@ -414,7 +413,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
|
||||
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
|
||||
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
|
||||
FetchLastChannelUpdate: fetchLastChanUpdate(
|
||||
p.server.chanRouter, p.PubKey(),
|
||||
p.server, p.PubKey(),
|
||||
),
|
||||
DebugHTLC: cfg.DebugHTLC,
|
||||
HodlMask: cfg.Hodl.Mask(),
|
||||
@ -1392,7 +1391,7 @@ out:
|
||||
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
|
||||
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
|
||||
FetchLastChannelUpdate: fetchLastChanUpdate(
|
||||
p.server.chanRouter, p.PubKey(),
|
||||
p.server, p.PubKey(),
|
||||
),
|
||||
DebugHTLC: cfg.DebugHTLC,
|
||||
HodlMask: cfg.Hodl.Mask(),
|
||||
@ -1911,11 +1910,11 @@ func (p *peer) PubKey() [33]byte {
|
||||
|
||||
// fetchLastChanUpdate returns a function which is able to retrieve the last
|
||||
// channel update for a target channel.
|
||||
func fetchLastChanUpdate(router *routing.ChannelRouter,
|
||||
func fetchLastChanUpdate(s *server,
|
||||
pubKey [33]byte) func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
|
||||
|
||||
return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) {
|
||||
info, edge1, edge2, err := router.GetChannelByID(cid)
|
||||
info, edge1, edge2, err := s.chanRouter.GetChannelByID(cid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
15
server.go
15
server.go
@ -153,7 +153,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
|
||||
|
||||
globalFeatures := lnwire.NewRawFeatureVector()
|
||||
|
||||
serializedPubKey := privKey.PubKey().SerializeCompressed()
|
||||
var serializedPubKey [33]byte
|
||||
copy(serializedPubKey[:], privKey.PubKey().SerializeCompressed())
|
||||
|
||||
// Initialize the sphinx router, placing it's persistent replay log in
|
||||
// the same directory as the channel graph database.
|
||||
@ -175,7 +176,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
|
||||
// TODO(roasbeef): derive proper onion key based on rotation
|
||||
// schedule
|
||||
sphinx: htlcswitch.NewOnionProcessor(sphinxRouter),
|
||||
lightningID: sha256.Sum256(serializedPubKey),
|
||||
lightningID: sha256.Sum256(serializedPubKey[:]),
|
||||
|
||||
persistentPeers: make(map[string]struct{}),
|
||||
persistentPeersBackoff: make(map[string]time.Duration),
|
||||
@ -209,7 +210,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
|
||||
debugPre[:], debugHash[:])
|
||||
}
|
||||
|
||||
htlcSwitch, err := htlcswitch.New(htlcswitch.Config{
|
||||
s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{
|
||||
DB: chanDB,
|
||||
SelfKey: s.identityPriv.PubKey(),
|
||||
LocalChannelClose: func(pubKey []byte,
|
||||
@ -234,14 +235,14 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
|
||||
pubKey[:], err)
|
||||
}
|
||||
},
|
||||
FwdingLog: chanDB.ForwardingLog(),
|
||||
SwitchPackager: channeldb.NewSwitchPackager(),
|
||||
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
|
||||
FwdingLog: chanDB.ForwardingLog(),
|
||||
SwitchPackager: channeldb.NewSwitchPackager(),
|
||||
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
|
||||
FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.htlcSwitch = htlcSwitch
|
||||
|
||||
// If external IP addresses have been specified, add those to the list
|
||||
// of this server's addresses. We need to use the cfg.net.ResolveTCPAddr
|
||||
|
Loading…
Reference in New Issue
Block a user