htlcswitch+router: move deobfuscator creation to GetPaymentResult call

In this commit we move handing the deobfuscator from the router to the
switch from when the payment is initiated, to when the result is
queried.

We do this because only the router can recreate the deobfuscator after a
restart, and we are preparing for being able to handle results across
restarts.

Since the deobfuscator cannot be nil anymore, we can also get rid of
that special case.
This commit is contained in:
Johan T. Halseth 2019-05-16 15:27:29 +02:00
parent f99d0c4c68
commit cd02c22977
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
6 changed files with 132 additions and 82 deletions

@ -1112,21 +1112,26 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) {
// Send payment and expose err channel.
err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc,
newMockDeobfuscator(),
)
if err != nil {
t.Fatalf("unable to get send payment: %v", err)
}
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid)
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil {
t.Fatalf("unable to get payment result: %v", err)
}
var result *PaymentResult
var ok bool
select {
case result = <-resultChan:
case result, ok = <-resultChan:
if !ok {
t.Fatalf("unexpected shutdown")
}
case <-time.After(5 * time.Second):
t.Fatalf("no result arrive")
}
@ -3888,19 +3893,24 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
// properly.
err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc,
newMockDeobfuscator(),
)
if err != nil {
t.Fatalf("unable to send payment to carol: %v", err)
}
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid)
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil {
t.Fatalf("unable to get payment result: %v", err)
}
select {
case result := <-resultChan:
case result, ok := <-resultChan:
if !ok {
t.Fatalf("unexpected shutdown")
}
if result.Error != nil {
t.Fatalf("payment failed: %v", result.Error)
}
@ -3912,7 +3922,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
// as it's a duplicate request.
err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc,
newMockDeobfuscator(),
)
if err != ErrAlreadyPaid {
t.Fatalf("ErrAlreadyPaid should have been received got: %v", err)

@ -71,12 +71,7 @@ type pendingPayment struct {
paymentHash lntypes.Hash
amount lnwire.MilliSatoshi
resultChan chan *PaymentResult
// deobfuscator is a serializable entity which is used if we received
// an error, it deobfuscates the onion failure blob, and extracts the
// exact error from it.
deobfuscator ErrorDecrypter
resultChan chan *networkResult
}
// plexPacket encapsulates switch packet and adds error channel to receive
@ -347,9 +342,13 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
// GetPaymentResult returns the the result of the payment attempt with the
// given paymentID. The method returns a channel where the payment result will
// be sent when available, or an error is encountered. If the paymentID is
// unknown, ErrPaymentIDNotFound will be returned.
func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, error) {
// be sent when available, or an error is encountered during forwarding. When a
// result is received on the channel, the HTLC is guaranteed to no longer be in
// flight. The switch shutting down is signaled by closing the channel. If the
// paymentID is unknown, ErrPaymentIDNotFound will be returned.
func (s *Switch) GetPaymentResult(paymentID uint64,
deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) {
s.pendingMutex.Lock()
payment, ok := s.pendingPayments[paymentID]
s.pendingMutex.Unlock()
@ -358,7 +357,42 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro
return nil, ErrPaymentIDNotFound
}
return payment.resultChan, nil
resultChan := make(chan *PaymentResult, 1)
// Since the payment was known, we can start a goroutine that can
// extract the result when it is available, and pass it on to the
// caller.
s.wg.Add(1)
go func() {
defer s.wg.Done()
var n *networkResult
select {
case n = <-payment.resultChan:
case <-s.quit:
// We close the result channel to signal a shutdown. We
// don't send any result in this case since the HTLC is
// still in flight.
close(resultChan)
return
}
// Extract the result and pass it to the result channel.
result, err := s.extractResult(
deobfuscator, n, paymentID, payment.paymentHash,
)
if err != nil {
e := fmt.Errorf("Unable to extract result: %v", err)
log.Error(e)
resultChan <- &PaymentResult{
Error: e,
}
return
}
resultChan <- result
}()
return resultChan, nil
}
// SendHTLC is used by other subsystems which aren't belong to htlc switch
@ -366,7 +400,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro
// for this HTLC, and MUST be used only once, otherwise the switch might reject
// it.
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) error {
htlc *lnwire.UpdateAddHTLC) error {
// Before sending, double check that we don't already have 1) an
// in-flight payment to this payment hash, or 2) a complete payment for
@ -378,10 +412,9 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
// Create payment and add to the map of payment in order later to be
// able to retrieve it and return response to the user.
payment := &pendingPayment{
resultChan: make(chan *PaymentResult, 1),
resultChan: make(chan *networkResult, 1),
paymentHash: htlc.PaymentHash,
amount: htlc.Amount,
deobfuscator: deobfuscator,
}
s.pendingMutex.Lock()
@ -889,25 +922,16 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
isResolution: pkt.isResolution,
}
result, err := s.extractResult(
payment, n, pkt.incomingHTLCID,
pkt.circuit.PaymentHash,
)
if err != nil {
log.Errorf("Unable to extract result: %v", err)
return
}
// Deliver the payment error and preimage to the application, if it is
// waiting for a response.
if payment != nil {
payment.resultChan <- result
payment.resultChan <- n
}
}
// extractResult uses the given deobfuscator to extract the payment result from
// the given network message.
func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) {
switch htlc := n.msg.(type) {
@ -940,7 +964,7 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
"%x: %v", paymentHash, err)
}
paymentErr := s.parseFailedPayment(
payment, paymentID, payment.paymentHash, n.unencrypted,
deobfuscator, paymentID, paymentHash, n.unencrypted,
n.isResolution, htlc,
)
@ -961,9 +985,9 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
// reason attached.
// 3) A failure from the remote party, which will need to be decrypted using
// the payment deobfuscator.
func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64,
paymentHash lntypes.Hash, unencrypted, isResolution bool,
htlc *lnwire.UpdateFailHTLC) *ForwardingError {
func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter,
paymentID uint64, paymentHash lntypes.Hash, unencrypted,
isResolution bool, htlc *lnwire.UpdateFailHTLC) *ForwardingError {
var failure *ForwardingError
@ -1007,25 +1031,13 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64,
FailureMessage: lnwire.FailPermanentChannelFailure{},
}
// If the provided payment is nil, we have discarded the error decryptor
// due to a restart. We'll return a fixed error and signal a temporary
// channel failure to the router.
case payment == nil:
userErr := fmt.Sprintf("error decryptor for payment " +
"could not be located, likely due to restart")
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.NewTemporaryChannelFailure(nil),
}
// A regular multi-hop payment error that we'll need to
// decrypt.
default:
var err error
// We'll attempt to fully decrypt the onion encrypted
// error. If we're unable to then we'll bail early.
failure, err = payment.deobfuscator.DecryptError(htlc.Reason)
failure, err = deobfuscator.DecryptError(htlc.Reason)
if err != nil {
userErr := fmt.Sprintf("unable to de-obfuscate "+
"onion failure (hash=%v, pid=%d): %v",

@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool,
// We'll attempt to send out a new HTLC that has Alice as the first
// outgoing link. This should fail as Alice isn't yet able to forward
// any active HTLC's.
err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg, nil)
err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg)
if err == nil {
t.Fatalf("local forward should fail due to inactive link")
}
@ -1742,7 +1742,9 @@ func TestSwitchSendPayment(t *testing.T) {
// First check that the switch will correctly respond that this payment
// ID is unknown.
_, err = s.GetPaymentResult(paymentID)
_, err = s.GetPaymentResult(
paymentID, newMockDeobfuscator(),
)
if err != ErrPaymentIDNotFound {
t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
}
@ -1752,19 +1754,25 @@ func TestSwitchSendPayment(t *testing.T) {
go func() {
err := s.SendHTLC(
aliceChannelLink.ShortChanID(), paymentID, update,
newMockDeobfuscator())
)
if err != nil {
errChan <- err
return
}
resultChan, err := s.GetPaymentResult(paymentID)
resultChan, err := s.GetPaymentResult(
paymentID, newMockDeobfuscator(),
)
if err != nil {
errChan <- err
return
}
result := <-resultChan
result, ok := <-resultChan
if !ok {
errChan <- fmt.Errorf("shutting down")
}
if result.Error != nil {
errChan <- result.Error
return

@ -795,17 +795,23 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
// Send payment and expose err channel.
return invoice, func() error {
err := sender.htlcSwitch.SendHTLC(
firstHop, pid, htlc, newMockDeobfuscator(),
firstHop, pid, htlc,
)
if err != nil {
return err
}
resultChan, err := sender.htlcSwitch.GetPaymentResult(pid)
resultChan, err := sender.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil {
return err
}
result := <-resultChan
result, ok := <-resultChan
if !ok {
return fmt.Errorf("shutting down")
}
if result.Error != nil {
return result.Error
}
@ -1275,20 +1281,26 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
// Send payment and expose err channel.
go func() {
err := sender.htlcSwitch.SendHTLC(
firstHop, pid, htlc, newMockDeobfuscator(),
firstHop, pid, htlc,
)
if err != nil {
paymentErr <- err
return
}
resultChan, err := sender.htlcSwitch.GetPaymentResult(pid)
resultChan, err := sender.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil {
paymentErr <- err
return
}
result := <-resultChan
result, ok := <-resultChan
if !ok {
paymentErr <- fmt.Errorf("shutting down")
}
if result.Error != nil {
paymentErr <- result.Error
return

@ -14,8 +14,7 @@ var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
pid uint64,
_ *lnwire.UpdateAddHTLC,
_ htlcswitch.ErrorDecrypter) error {
_ *lnwire.UpdateAddHTLC) error {
if m.onPayment == nil {
return nil
@ -44,8 +43,8 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
return nil
}
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64) (
<-chan *htlcswitch.PaymentResult, error) {
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
_ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {
c := make(chan *htlcswitch.PaymentResult, 1)
res, ok := m.results[paymentID]

@ -134,15 +134,16 @@ type PaymentAttemptDispatcher interface {
// payment was unsuccessful.
SendHTLC(firstHop lnwire.ShortChannelID,
paymentID uint64,
htlcAdd *lnwire.UpdateAddHTLC,
deobfuscator htlcswitch.ErrorDecrypter) error
htlcAdd *lnwire.UpdateAddHTLC) error
// GetPaymentResult returns the the result of the payment attempt with
// the given paymentID. The method returns a channel where the payment
// result will be sent when available, or an error is encountered. If
// the paymentID is unknown, htlcswitch.ErrPaymentIDNotFound will be
// returned.
GetPaymentResult(paymentID uint64) (
// result will be sent when available, or an error is encountered
// during forwarding. When a result is received on the channel, the
// HTLC is guaranteed to no longer be in flight. The switch shutting
// down is signaled by closing the channel. If the paymentID is
// unknown, ErrPaymentIDNotFound will be returned.
GetPaymentResult(paymentID uint64, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error)
}
@ -1710,13 +1711,6 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
route.Hops[0].ChannelID,
)
// Using the created circuit, initialize the error decrypter so we can
// parse+decode any failures incurred by this payment within the
// switch.
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
}
// We generate a new, unique payment ID that we will use for
// this HTLC.
paymentID, err := r.cfg.NextPaymentID()
@ -1725,7 +1719,7 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
}
err = r.cfg.Payer.SendHTLC(
firstHop, paymentID, htlcAdd, errorDecryptor,
firstHop, paymentID, htlcAdd,
)
if err != nil {
log.Errorf("Failed sending attempt %d for payment %x to "+
@ -1740,18 +1734,34 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
return [32]byte{}, finalOutcome, err
}
// Using the created circuit, initialize the error decrypter so we can
// parse+decode any failures incurred by this payment within the
// switch.
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
}
// Now ask the switch to return the result of the payment when
// available.
resultChan, err := r.cfg.Payer.GetPaymentResult(paymentID)
resultChan, err := r.cfg.Payer.GetPaymentResult(
paymentID, errorDecryptor,
)
if err != nil {
log.Errorf("Failed getting result for paymentID %d "+
"from switch: %v", paymentID, err)
return [32]byte{}, true, err
}
var result *htlcswitch.PaymentResult
var (
result *htlcswitch.PaymentResult
ok bool
)
select {
case result = <-resultChan:
case result, ok = <-resultChan:
if !ok {
return [32]byte{}, true, htlcswitch.ErrSwitchExiting
}
case <-r.quit:
return [32]byte{}, true, ErrRouterShuttingDown
}