From cd02c229777c941c60dd6ae5fafeac2f1d890844 Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Thu, 16 May 2019 15:27:29 +0200 Subject: [PATCH] htlcswitch+router: move deobfuscator creation to GetPaymentResult call In this commit we move handing the deobfuscator from the router to the switch from when the payment is initiated, to when the result is queried. We do this because only the router can recreate the deobfuscator after a restart, and we are preparing for being able to handle results across restarts. Since the deobfuscator cannot be nil anymore, we can also get rid of that special case. --- htlcswitch/link_test.go | 23 ++++++--- htlcswitch/switch.go | 98 ++++++++++++++++++++++----------------- htlcswitch/switch_test.go | 18 +++++-- htlcswitch/test_utils.go | 24 +++++++--- routing/mock_test.go | 7 ++- routing/router.go | 44 +++++++++++------- 6 files changed, 132 insertions(+), 82 deletions(-) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index c3d77d9f..fb3cfc56 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1112,21 +1112,26 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { // Send payment and expose err channel. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, - newMockDeobfuscator(), ) if err != nil { t.Fatalf("unable to get send payment: %v", err) } - resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid) + resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { t.Fatalf("unable to get payment result: %v", err) } var result *PaymentResult + var ok bool select { - case result = <-resultChan: + case result, ok = <-resultChan: + if !ok { + t.Fatalf("unexpected shutdown") + } case <-time.After(5 * time.Second): t.Fatalf("no result arrive") } @@ -3888,19 +3893,24 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // properly. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, - newMockDeobfuscator(), ) if err != nil { t.Fatalf("unable to send payment to carol: %v", err) } - resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid) + resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { t.Fatalf("unable to get payment result: %v", err) } select { - case result := <-resultChan: + case result, ok := <-resultChan: + if !ok { + t.Fatalf("unexpected shutdown") + } + if result.Error != nil { t.Fatalf("payment failed: %v", result.Error) } @@ -3912,7 +3922,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // as it's a duplicate request. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, - newMockDeobfuscator(), ) if err != ErrAlreadyPaid { t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index e7af896f..c850783e 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -71,12 +71,7 @@ type pendingPayment struct { paymentHash lntypes.Hash amount lnwire.MilliSatoshi - resultChan chan *PaymentResult - - // deobfuscator is a serializable entity which is used if we received - // an error, it deobfuscates the onion failure blob, and extracts the - // exact error from it. - deobfuscator ErrorDecrypter + resultChan chan *networkResult } // plexPacket encapsulates switch packet and adds error channel to receive @@ -347,9 +342,13 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro // GetPaymentResult returns the the result of the payment attempt with the // given paymentID. The method returns a channel where the payment result will -// be sent when available, or an error is encountered. If the paymentID is -// unknown, ErrPaymentIDNotFound will be returned. -func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, error) { +// be sent when available, or an error is encountered during forwarding. When a +// result is received on the channel, the HTLC is guaranteed to no longer be in +// flight. The switch shutting down is signaled by closing the channel. If the +// paymentID is unknown, ErrPaymentIDNotFound will be returned. +func (s *Switch) GetPaymentResult(paymentID uint64, + deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) { + s.pendingMutex.Lock() payment, ok := s.pendingPayments[paymentID] s.pendingMutex.Unlock() @@ -358,7 +357,42 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro return nil, ErrPaymentIDNotFound } - return payment.resultChan, nil + resultChan := make(chan *PaymentResult, 1) + + // Since the payment was known, we can start a goroutine that can + // extract the result when it is available, and pass it on to the + // caller. + s.wg.Add(1) + go func() { + defer s.wg.Done() + + var n *networkResult + select { + case n = <-payment.resultChan: + case <-s.quit: + // We close the result channel to signal a shutdown. We + // don't send any result in this case since the HTLC is + // still in flight. + close(resultChan) + return + } + + // Extract the result and pass it to the result channel. + result, err := s.extractResult( + deobfuscator, n, paymentID, payment.paymentHash, + ) + if err != nil { + e := fmt.Errorf("Unable to extract result: %v", err) + log.Error(e) + resultChan <- &PaymentResult{ + Error: e, + } + return + } + resultChan <- result + }() + + return resultChan, nil } // SendHTLC is used by other subsystems which aren't belong to htlc switch @@ -366,7 +400,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro // for this HTLC, and MUST be used only once, otherwise the switch might reject // it. func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, - htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) error { + htlc *lnwire.UpdateAddHTLC) error { // Before sending, double check that we don't already have 1) an // in-flight payment to this payment hash, or 2) a complete payment for @@ -378,10 +412,9 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ - resultChan: make(chan *PaymentResult, 1), - paymentHash: htlc.PaymentHash, - amount: htlc.Amount, - deobfuscator: deobfuscator, + resultChan: make(chan *networkResult, 1), + paymentHash: htlc.PaymentHash, + amount: htlc.Amount, } s.pendingMutex.Lock() @@ -889,25 +922,16 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) { isResolution: pkt.isResolution, } - result, err := s.extractResult( - payment, n, pkt.incomingHTLCID, - pkt.circuit.PaymentHash, - ) - if err != nil { - log.Errorf("Unable to extract result: %v", err) - return - } - // Deliver the payment error and preimage to the application, if it is // waiting for a response. if payment != nil { - payment.resultChan <- result + payment.resultChan <- n } } // extractResult uses the given deobfuscator to extract the payment result from // the given network message. -func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, +func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult, paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) { switch htlc := n.msg.(type) { @@ -940,7 +964,7 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, "%x: %v", paymentHash, err) } paymentErr := s.parseFailedPayment( - payment, paymentID, payment.paymentHash, n.unencrypted, + deobfuscator, paymentID, paymentHash, n.unencrypted, n.isResolution, htlc, ) @@ -961,9 +985,9 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, // reason attached. // 3) A failure from the remote party, which will need to be decrypted using // the payment deobfuscator. -func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64, - paymentHash lntypes.Hash, unencrypted, isResolution bool, - htlc *lnwire.UpdateFailHTLC) *ForwardingError { +func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, + paymentID uint64, paymentHash lntypes.Hash, unencrypted, + isResolution bool, htlc *lnwire.UpdateFailHTLC) *ForwardingError { var failure *ForwardingError @@ -1007,25 +1031,13 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64, FailureMessage: lnwire.FailPermanentChannelFailure{}, } - // If the provided payment is nil, we have discarded the error decryptor - // due to a restart. We'll return a fixed error and signal a temporary - // channel failure to the router. - case payment == nil: - userErr := fmt.Sprintf("error decryptor for payment " + - "could not be located, likely due to restart") - failure = &ForwardingError{ - ErrorSource: s.cfg.SelfKey, - ExtraMsg: userErr, - FailureMessage: lnwire.NewTemporaryChannelFailure(nil), - } - // A regular multi-hop payment error that we'll need to // decrypt. default: var err error // We'll attempt to fully decrypt the onion encrypted // error. If we're unable to then we'll bail early. - failure, err = payment.deobfuscator.DecryptError(htlc.Reason) + failure, err = deobfuscator.DecryptError(htlc.Reason) if err != nil { userErr := fmt.Sprintf("unable to de-obfuscate "+ "onion failure (hash=%v, pid=%d): %v", diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index f7eec10f..cd798301 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, // We'll attempt to send out a new HTLC that has Alice as the first // outgoing link. This should fail as Alice isn't yet able to forward // any active HTLC's. - err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg, nil) + err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg) if err == nil { t.Fatalf("local forward should fail due to inactive link") } @@ -1742,7 +1742,9 @@ func TestSwitchSendPayment(t *testing.T) { // First check that the switch will correctly respond that this payment // ID is unknown. - _, err = s.GetPaymentResult(paymentID) + _, err = s.GetPaymentResult( + paymentID, newMockDeobfuscator(), + ) if err != ErrPaymentIDNotFound { t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) } @@ -1752,19 +1754,25 @@ func TestSwitchSendPayment(t *testing.T) { go func() { err := s.SendHTLC( aliceChannelLink.ShortChanID(), paymentID, update, - newMockDeobfuscator()) + ) if err != nil { errChan <- err return } - resultChan, err := s.GetPaymentResult(paymentID) + resultChan, err := s.GetPaymentResult( + paymentID, newMockDeobfuscator(), + ) if err != nil { errChan <- err return } - result := <-resultChan + result, ok := <-resultChan + if !ok { + errChan <- fmt.Errorf("shutting down") + } + if result.Error != nil { errChan <- result.Error return diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index ad3c1bf8..fa8ff886 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -795,17 +795,23 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. return invoice, func() error { err := sender.htlcSwitch.SendHTLC( - firstHop, pid, htlc, newMockDeobfuscator(), + firstHop, pid, htlc, ) if err != nil { return err } - resultChan, err := sender.htlcSwitch.GetPaymentResult(pid) + resultChan, err := sender.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { return err } - result := <-resultChan + result, ok := <-resultChan + if !ok { + return fmt.Errorf("shutting down") + } + if result.Error != nil { return result.Error } @@ -1275,20 +1281,26 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. go func() { err := sender.htlcSwitch.SendHTLC( - firstHop, pid, htlc, newMockDeobfuscator(), + firstHop, pid, htlc, ) if err != nil { paymentErr <- err return } - resultChan, err := sender.htlcSwitch.GetPaymentResult(pid) + resultChan, err := sender.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { paymentErr <- err return } - result := <-resultChan + result, ok := <-resultChan + if !ok { + paymentErr <- fmt.Errorf("shutting down") + } + if result.Error != nil { paymentErr <- result.Error return diff --git a/routing/mock_test.go b/routing/mock_test.go index e3f324ef..03aa2923 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -14,8 +14,7 @@ var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, pid uint64, - _ *lnwire.UpdateAddHTLC, - _ htlcswitch.ErrorDecrypter) error { + _ *lnwire.UpdateAddHTLC) error { if m.onPayment == nil { return nil @@ -44,8 +43,8 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, return nil } -func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64) ( - <-chan *htlcswitch.PaymentResult, error) { +func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64, + _ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) { c := make(chan *htlcswitch.PaymentResult, 1) res, ok := m.results[paymentID] diff --git a/routing/router.go b/routing/router.go index 105b69c2..a03c3c9e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -134,15 +134,16 @@ type PaymentAttemptDispatcher interface { // payment was unsuccessful. SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, - htlcAdd *lnwire.UpdateAddHTLC, - deobfuscator htlcswitch.ErrorDecrypter) error + htlcAdd *lnwire.UpdateAddHTLC) error // GetPaymentResult returns the the result of the payment attempt with // the given paymentID. The method returns a channel where the payment - // result will be sent when available, or an error is encountered. If - // the paymentID is unknown, htlcswitch.ErrPaymentIDNotFound will be - // returned. - GetPaymentResult(paymentID uint64) ( + // result will be sent when available, or an error is encountered + // during forwarding. When a result is received on the channel, the + // HTLC is guaranteed to no longer be in flight. The switch shutting + // down is signaled by closing the channel. If the paymentID is + // unknown, ErrPaymentIDNotFound will be returned. + GetPaymentResult(paymentID uint64, deobfuscator htlcswitch.ErrorDecrypter) ( <-chan *htlcswitch.PaymentResult, error) } @@ -1710,13 +1711,6 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, route.Hops[0].ChannelID, ) - // Using the created circuit, initialize the error decrypter so we can - // parse+decode any failures incurred by this payment within the - // switch. - errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), - } - // We generate a new, unique payment ID that we will use for // this HTLC. paymentID, err := r.cfg.NextPaymentID() @@ -1725,7 +1719,7 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, } err = r.cfg.Payer.SendHTLC( - firstHop, paymentID, htlcAdd, errorDecryptor, + firstHop, paymentID, htlcAdd, ) if err != nil { log.Errorf("Failed sending attempt %d for payment %x to "+ @@ -1740,18 +1734,34 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, return [32]byte{}, finalOutcome, err } + // Using the created circuit, initialize the error decrypter so we can + // parse+decode any failures incurred by this payment within the + // switch. + errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ + OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), + } + // Now ask the switch to return the result of the payment when // available. - resultChan, err := r.cfg.Payer.GetPaymentResult(paymentID) + resultChan, err := r.cfg.Payer.GetPaymentResult( + paymentID, errorDecryptor, + ) if err != nil { log.Errorf("Failed getting result for paymentID %d "+ "from switch: %v", paymentID, err) return [32]byte{}, true, err } - var result *htlcswitch.PaymentResult + var ( + result *htlcswitch.PaymentResult + ok bool + ) select { - case result = <-resultChan: + case result, ok = <-resultChan: + if !ok { + return [32]byte{}, true, htlcswitch.ErrSwitchExiting + } + case <-r.quit: return [32]byte{}, true, ErrRouterShuttingDown }