diff --git a/htlcswitch/link.go b/htlcswitch/link.go index dfb73436..b08a975a 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -982,10 +982,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { reason lnwire.OpaqueReason ) - failure := lnwire.NewTemporaryChannelFailure(nil) + var failure lnwire.FailureMessage + update, err := l.cfg.FetchLastChannelUpdate( + l.ShortChanID(), + ) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure( + update, + ) + } - // Encrypt the error back to the source unless the payment was - // generated locally. + // Encrypt the error back to the source unless + // the payment was generated locally. if pkt.obfuscator == nil { var b bytes.Buffer err := lnwire.EncodeFailure(&b, failure, 0) @@ -1652,11 +1662,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.shortChanID, - ) + update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) if err != nil { - failure = lnwire.NewTemporaryChannelFailure(nil) + failure = &lnwire.FailTemporaryNodeFailure{} } else { failure = lnwire.NewAmountBelowMinimum( amtToForward, *update, @@ -1686,11 +1694,9 @@ func (l *channelLink) HtlcSatifiesPolicy(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. var failure lnwire.FailureMessage - update, err := l.cfg.FetchLastChannelUpdate( - l.shortChanID, - ) + update, err := l.cfg.FetchLastChannelUpdate(l.ShortChanID()) if err != nil { - failure = lnwire.NewTemporaryChannelFailure(nil) + failure = &lnwire.FailTemporaryNodeFailure{} } else { failure = lnwire.NewFeeInsufficient( amtToForward, *update, @@ -2242,10 +2248,12 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, var failure lnwire.FailureMessage update, err := l.cfg.FetchLastChannelUpdate( - l.shortChanID, + l.ShortChanID(), ) if err != nil { - failure = lnwire.NewTemporaryChannelFailure(nil) + failure = lnwire.NewTemporaryChannelFailure( + update, + ) } else { failure = lnwire.NewExpiryTooSoon(*update) } @@ -2275,7 +2283,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // sending node is up to date with our current // policy. update, err := l.cfg.FetchLastChannelUpdate( - l.shortChanID, + l.ShortChanID(), ) if err != nil { l.fail("unable to create channel update "+ @@ -2313,7 +2321,17 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, log.Errorf("unable to encode the "+ "remaining route %v", err) - failure := lnwire.NewTemporaryChannelFailure(nil) + var failure lnwire.FailureMessage + update, err := l.cfg.FetchLastChannelUpdate( + l.ShortChanID(), + ) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure( + update, + ) + } l.sendHTLCError( pd.HtlcIndex, failure, obfuscator, pd.SourceRef, diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 3bd7c4d1..4e5536b6 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1459,8 +1459,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } aliceDb := aliceChannel.State().Db - - aliceSwitch, err := New(Config{DB: aliceDb}) + aliceSwitch, err := initSwitchWithDB(aliceDb) if err != nil { return nil, nil, nil, nil, nil, err } @@ -3854,7 +3853,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, if aliceSwitch == nil { var err error - aliceSwitch, err = New(Config{DB: aliceDb}) + aliceSwitch, err = initSwitchWithDB(aliceDb) if err != nil { return nil, nil, nil, err } diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 33898293..33e3e1a9 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -140,6 +140,9 @@ func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { FwdingLog: &mockForwardingLog{ events: make(map[time.Time]channeldb.ForwardingEvent), }, + FetchLastChannelUpdate: func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { + return nil, nil + }, }) } diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index cf1750c8..8028ef93 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -135,6 +135,13 @@ type Config struct { // error encrypters stored in the circuit map on restarts, since they // are not stored directly within the database. ExtractErrorEncrypter ErrorEncrypterExtracter + + // FetchLastChannelUpdate retrieves the latest routing policy for a + // target channel. This channel will typically be the outgoing channel + // specified when we receive an incoming HTLC. This will be used to + // provide payment senders our latest policy when sending encrypted + // error messages. + FetchLastChannelUpdate func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) } // Switch is the central messaging bus for all incoming/outgoing HTLCs. @@ -458,7 +465,15 @@ func (s *Switch) forward(packet *htlcPacket) error { return err } - failure := lnwire.NewTemporaryChannelFailure(nil) + var failure lnwire.FailureMessage + update, err := s.cfg.FetchLastChannelUpdate( + packet.incomingChanID, + ) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure(update) + } addErr := ErrIncompleteForward return s.failAddPacket(packet, failure, addErr) @@ -588,14 +603,25 @@ func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { // Lastly, for any packets that failed, this implies that they were // left in a half added state, which can happen when recovering from // failures. - for _, packet := range failedPackets { - failure := lnwire.NewTemporaryChannelFailure(nil) - addErr := errors.Errorf("failing packet after detecting " + - "incomplete forward") + if len(failedPackets) > 0 { + var failure lnwire.FailureMessage + update, err := s.cfg.FetchLastChannelUpdate( + failedPackets[0].incomingChanID, + ) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure(update) + } - // We don't handle the error here since this method always - // returns an error. - s.failAddPacket(packet, failure, addErr) + for _, packet := range failedPackets { + addErr := errors.Errorf("failing packet after " + + "detecting incomplete forward") + + // We don't handle the error here since this method + // always returns an error. + s.failAddPacket(packet, failure, addErr) + } } return errChan @@ -749,6 +775,8 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { htlc.Amount, largestBandwidth) log.Error(err) + // Note that we don't need to populate an update here, + // as this will go directly back to the router. htlcErr := lnwire.NewTemporaryChannelFailure(nil) return &ForwardingError{ ErrorSource: s.cfg.SelfKey, @@ -812,6 +840,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, userErr = fmt.Sprintf("unable to decode onion failure, "+ "htlc with hash(%x): %v", payment.paymentHash[:], err) log.Error(userErr) + + // As this didn't even clear the link, we don't need to + // apply an update here since it goes directly to the + // router. failureMsg = lnwire.NewTemporaryChannelFailure(nil) } failure = &ForwardingError{ @@ -938,7 +970,16 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // If packet was forwarded from another channel link // than we should notify this link that some error // occurred. - failure := lnwire.NewTemporaryChannelFailure(nil) + var failure lnwire.FailureMessage + update, err := s.cfg.FetchLastChannelUpdate( + packet.outgoingChanID, + ) + if err != nil { + failure = &lnwire.FailTemporaryNodeFailure{} + } else { + failure = lnwire.NewTemporaryChannelFailure(update) + } + addErr := errors.Errorf("unable to find appropriate "+ "channel link insufficient capacity, need "+ "%v", htlc.Amount) @@ -1799,7 +1840,7 @@ func (s *Switch) UpdateShortChanID(chanID lnwire.ChannelID, s.indexMtx.Lock() - // First, we'll extract the current link as is from the link + // First, we'll extract the current link as is from the link // index. If the link isn't even in the index, then we'll return an // error. link, ok := s.linkIndex[chanID] diff --git a/peer.go b/peer.go index ad9dc946..517f4bfe 100644 --- a/peer.go +++ b/peer.go @@ -21,7 +21,6 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing" "github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcd/connmgr" "github.com/roasbeef/btcd/txscript" @@ -414,7 +413,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error { DecodeHopIterators: p.server.sphinx.DecodeHopIterators, ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter, FetchLastChannelUpdate: fetchLastChanUpdate( - p.server.chanRouter, p.PubKey(), + p.server, p.PubKey(), ), DebugHTLC: cfg.DebugHTLC, HodlMask: cfg.Hodl.Mask(), @@ -1392,7 +1391,7 @@ out: DecodeHopIterators: p.server.sphinx.DecodeHopIterators, ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter, FetchLastChannelUpdate: fetchLastChanUpdate( - p.server.chanRouter, p.PubKey(), + p.server, p.PubKey(), ), DebugHTLC: cfg.DebugHTLC, HodlMask: cfg.Hodl.Mask(), @@ -1911,11 +1910,11 @@ func (p *peer) PubKey() [33]byte { // fetchLastChanUpdate returns a function which is able to retrieve the last // channel update for a target channel. -func fetchLastChanUpdate(router *routing.ChannelRouter, +func fetchLastChanUpdate(s *server, pubKey [33]byte) func(lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { return func(cid lnwire.ShortChannelID) (*lnwire.ChannelUpdate, error) { - info, edge1, edge2, err := router.GetChannelByID(cid) + info, edge1, edge2, err := s.chanRouter.GetChannelByID(cid) if err != nil { return nil, err } diff --git a/server.go b/server.go index 9c259f44..3f5cbe94 100644 --- a/server.go +++ b/server.go @@ -153,7 +153,8 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, globalFeatures := lnwire.NewRawFeatureVector() - serializedPubKey := privKey.PubKey().SerializeCompressed() + var serializedPubKey [33]byte + copy(serializedPubKey[:], privKey.PubKey().SerializeCompressed()) // Initialize the sphinx router, placing it's persistent replay log in // the same directory as the channel graph database. @@ -175,7 +176,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, // TODO(roasbeef): derive proper onion key based on rotation // schedule sphinx: htlcswitch.NewOnionProcessor(sphinxRouter), - lightningID: sha256.Sum256(serializedPubKey), + lightningID: sha256.Sum256(serializedPubKey[:]), persistentPeers: make(map[string]struct{}), persistentPeersBackoff: make(map[string]time.Duration), @@ -209,7 +210,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, debugPre[:], debugHash[:]) } - htlcSwitch, err := htlcswitch.New(htlcswitch.Config{ + s.htlcSwitch, err = htlcswitch.New(htlcswitch.Config{ DB: chanDB, SelfKey: s.identityPriv.PubKey(), LocalChannelClose: func(pubKey []byte, @@ -234,14 +235,14 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, pubKey[:], err) } }, - FwdingLog: chanDB.ForwardingLog(), - SwitchPackager: channeldb.NewSwitchPackager(), - ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter, + FwdingLog: chanDB.ForwardingLog(), + SwitchPackager: channeldb.NewSwitchPackager(), + ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter, + FetchLastChannelUpdate: fetchLastChanUpdate(s, serializedPubKey), }) if err != nil { return nil, err } - s.htlcSwitch = htlcSwitch // If external IP addresses have been specified, add those to the list // of this server's addresses. We need to use the cfg.net.ResolveTCPAddr