diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index c3d77d9f..fb3cfc56 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1112,21 +1112,26 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { // Send payment and expose err channel. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, - newMockDeobfuscator(), ) if err != nil { t.Fatalf("unable to get send payment: %v", err) } - resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid) + resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { t.Fatalf("unable to get payment result: %v", err) } var result *PaymentResult + var ok bool select { - case result = <-resultChan: + case result, ok = <-resultChan: + if !ok { + t.Fatalf("unexpected shutdown") + } case <-time.After(5 * time.Second): t.Fatalf("no result arrive") } @@ -3888,19 +3893,24 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // properly. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, - newMockDeobfuscator(), ) if err != nil { t.Fatalf("unable to send payment to carol: %v", err) } - resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid) + resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { t.Fatalf("unable to get payment result: %v", err) } select { - case result := <-resultChan: + case result, ok := <-resultChan: + if !ok { + t.Fatalf("unexpected shutdown") + } + if result.Error != nil { t.Fatalf("payment failed: %v", result.Error) } @@ -3912,7 +3922,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // as it's a duplicate request. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, - newMockDeobfuscator(), ) if err != ErrAlreadyPaid { t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index e7af896f..c850783e 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -71,12 +71,7 @@ type pendingPayment struct { paymentHash lntypes.Hash amount lnwire.MilliSatoshi - resultChan chan *PaymentResult - - // deobfuscator is a serializable entity which is used if we received - // an error, it deobfuscates the onion failure blob, and extracts the - // exact error from it. - deobfuscator ErrorDecrypter + resultChan chan *networkResult } // plexPacket encapsulates switch packet and adds error channel to receive @@ -347,9 +342,13 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro // GetPaymentResult returns the the result of the payment attempt with the // given paymentID. The method returns a channel where the payment result will -// be sent when available, or an error is encountered. If the paymentID is -// unknown, ErrPaymentIDNotFound will be returned. -func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, error) { +// be sent when available, or an error is encountered during forwarding. When a +// result is received on the channel, the HTLC is guaranteed to no longer be in +// flight. The switch shutting down is signaled by closing the channel. If the +// paymentID is unknown, ErrPaymentIDNotFound will be returned. +func (s *Switch) GetPaymentResult(paymentID uint64, + deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) { + s.pendingMutex.Lock() payment, ok := s.pendingPayments[paymentID] s.pendingMutex.Unlock() @@ -358,7 +357,42 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro return nil, ErrPaymentIDNotFound } - return payment.resultChan, nil + resultChan := make(chan *PaymentResult, 1) + + // Since the payment was known, we can start a goroutine that can + // extract the result when it is available, and pass it on to the + // caller. + s.wg.Add(1) + go func() { + defer s.wg.Done() + + var n *networkResult + select { + case n = <-payment.resultChan: + case <-s.quit: + // We close the result channel to signal a shutdown. We + // don't send any result in this case since the HTLC is + // still in flight. + close(resultChan) + return + } + + // Extract the result and pass it to the result channel. + result, err := s.extractResult( + deobfuscator, n, paymentID, payment.paymentHash, + ) + if err != nil { + e := fmt.Errorf("Unable to extract result: %v", err) + log.Error(e) + resultChan <- &PaymentResult{ + Error: e, + } + return + } + resultChan <- result + }() + + return resultChan, nil } // SendHTLC is used by other subsystems which aren't belong to htlc switch @@ -366,7 +400,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro // for this HTLC, and MUST be used only once, otherwise the switch might reject // it. func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, - htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) error { + htlc *lnwire.UpdateAddHTLC) error { // Before sending, double check that we don't already have 1) an // in-flight payment to this payment hash, or 2) a complete payment for @@ -378,10 +412,9 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ - resultChan: make(chan *PaymentResult, 1), - paymentHash: htlc.PaymentHash, - amount: htlc.Amount, - deobfuscator: deobfuscator, + resultChan: make(chan *networkResult, 1), + paymentHash: htlc.PaymentHash, + amount: htlc.Amount, } s.pendingMutex.Lock() @@ -889,25 +922,16 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) { isResolution: pkt.isResolution, } - result, err := s.extractResult( - payment, n, pkt.incomingHTLCID, - pkt.circuit.PaymentHash, - ) - if err != nil { - log.Errorf("Unable to extract result: %v", err) - return - } - // Deliver the payment error and preimage to the application, if it is // waiting for a response. if payment != nil { - payment.resultChan <- result + payment.resultChan <- n } } // extractResult uses the given deobfuscator to extract the payment result from // the given network message. -func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, +func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult, paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) { switch htlc := n.msg.(type) { @@ -940,7 +964,7 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, "%x: %v", paymentHash, err) } paymentErr := s.parseFailedPayment( - payment, paymentID, payment.paymentHash, n.unencrypted, + deobfuscator, paymentID, paymentHash, n.unencrypted, n.isResolution, htlc, ) @@ -961,9 +985,9 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, // reason attached. // 3) A failure from the remote party, which will need to be decrypted using // the payment deobfuscator. -func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64, - paymentHash lntypes.Hash, unencrypted, isResolution bool, - htlc *lnwire.UpdateFailHTLC) *ForwardingError { +func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, + paymentID uint64, paymentHash lntypes.Hash, unencrypted, + isResolution bool, htlc *lnwire.UpdateFailHTLC) *ForwardingError { var failure *ForwardingError @@ -1007,25 +1031,13 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64, FailureMessage: lnwire.FailPermanentChannelFailure{}, } - // If the provided payment is nil, we have discarded the error decryptor - // due to a restart. We'll return a fixed error and signal a temporary - // channel failure to the router. - case payment == nil: - userErr := fmt.Sprintf("error decryptor for payment " + - "could not be located, likely due to restart") - failure = &ForwardingError{ - ErrorSource: s.cfg.SelfKey, - ExtraMsg: userErr, - FailureMessage: lnwire.NewTemporaryChannelFailure(nil), - } - // A regular multi-hop payment error that we'll need to // decrypt. default: var err error // We'll attempt to fully decrypt the onion encrypted // error. If we're unable to then we'll bail early. - failure, err = payment.deobfuscator.DecryptError(htlc.Reason) + failure, err = deobfuscator.DecryptError(htlc.Reason) if err != nil { userErr := fmt.Sprintf("unable to de-obfuscate "+ "onion failure (hash=%v, pid=%d): %v", diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index f7eec10f..cd798301 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, // We'll attempt to send out a new HTLC that has Alice as the first // outgoing link. This should fail as Alice isn't yet able to forward // any active HTLC's. - err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg, nil) + err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg) if err == nil { t.Fatalf("local forward should fail due to inactive link") } @@ -1742,7 +1742,9 @@ func TestSwitchSendPayment(t *testing.T) { // First check that the switch will correctly respond that this payment // ID is unknown. - _, err = s.GetPaymentResult(paymentID) + _, err = s.GetPaymentResult( + paymentID, newMockDeobfuscator(), + ) if err != ErrPaymentIDNotFound { t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) } @@ -1752,19 +1754,25 @@ func TestSwitchSendPayment(t *testing.T) { go func() { err := s.SendHTLC( aliceChannelLink.ShortChanID(), paymentID, update, - newMockDeobfuscator()) + ) if err != nil { errChan <- err return } - resultChan, err := s.GetPaymentResult(paymentID) + resultChan, err := s.GetPaymentResult( + paymentID, newMockDeobfuscator(), + ) if err != nil { errChan <- err return } - result := <-resultChan + result, ok := <-resultChan + if !ok { + errChan <- fmt.Errorf("shutting down") + } + if result.Error != nil { errChan <- result.Error return diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index ad3c1bf8..fa8ff886 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -795,17 +795,23 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. return invoice, func() error { err := sender.htlcSwitch.SendHTLC( - firstHop, pid, htlc, newMockDeobfuscator(), + firstHop, pid, htlc, ) if err != nil { return err } - resultChan, err := sender.htlcSwitch.GetPaymentResult(pid) + resultChan, err := sender.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { return err } - result := <-resultChan + result, ok := <-resultChan + if !ok { + return fmt.Errorf("shutting down") + } + if result.Error != nil { return result.Error } @@ -1275,20 +1281,26 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. go func() { err := sender.htlcSwitch.SendHTLC( - firstHop, pid, htlc, newMockDeobfuscator(), + firstHop, pid, htlc, ) if err != nil { paymentErr <- err return } - resultChan, err := sender.htlcSwitch.GetPaymentResult(pid) + resultChan, err := sender.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) if err != nil { paymentErr <- err return } - result := <-resultChan + result, ok := <-resultChan + if !ok { + paymentErr <- fmt.Errorf("shutting down") + } + if result.Error != nil { paymentErr <- result.Error return diff --git a/routing/mock_test.go b/routing/mock_test.go index e3f324ef..03aa2923 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -14,8 +14,7 @@ var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, pid uint64, - _ *lnwire.UpdateAddHTLC, - _ htlcswitch.ErrorDecrypter) error { + _ *lnwire.UpdateAddHTLC) error { if m.onPayment == nil { return nil @@ -44,8 +43,8 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, return nil } -func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64) ( - <-chan *htlcswitch.PaymentResult, error) { +func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64, + _ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) { c := make(chan *htlcswitch.PaymentResult, 1) res, ok := m.results[paymentID] diff --git a/routing/router.go b/routing/router.go index 105b69c2..a03c3c9e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -134,15 +134,16 @@ type PaymentAttemptDispatcher interface { // payment was unsuccessful. SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, - htlcAdd *lnwire.UpdateAddHTLC, - deobfuscator htlcswitch.ErrorDecrypter) error + htlcAdd *lnwire.UpdateAddHTLC) error // GetPaymentResult returns the the result of the payment attempt with // the given paymentID. The method returns a channel where the payment - // result will be sent when available, or an error is encountered. If - // the paymentID is unknown, htlcswitch.ErrPaymentIDNotFound will be - // returned. - GetPaymentResult(paymentID uint64) ( + // result will be sent when available, or an error is encountered + // during forwarding. When a result is received on the channel, the + // HTLC is guaranteed to no longer be in flight. The switch shutting + // down is signaled by closing the channel. If the paymentID is + // unknown, ErrPaymentIDNotFound will be returned. + GetPaymentResult(paymentID uint64, deobfuscator htlcswitch.ErrorDecrypter) ( <-chan *htlcswitch.PaymentResult, error) } @@ -1710,13 +1711,6 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, route.Hops[0].ChannelID, ) - // Using the created circuit, initialize the error decrypter so we can - // parse+decode any failures incurred by this payment within the - // switch. - errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), - } - // We generate a new, unique payment ID that we will use for // this HTLC. paymentID, err := r.cfg.NextPaymentID() @@ -1725,7 +1719,7 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, } err = r.cfg.Payer.SendHTLC( - firstHop, paymentID, htlcAdd, errorDecryptor, + firstHop, paymentID, htlcAdd, ) if err != nil { log.Errorf("Failed sending attempt %d for payment %x to "+ @@ -1740,18 +1734,34 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, return [32]byte{}, finalOutcome, err } + // Using the created circuit, initialize the error decrypter so we can + // parse+decode any failures incurred by this payment within the + // switch. + errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ + OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), + } + // Now ask the switch to return the result of the payment when // available. - resultChan, err := r.cfg.Payer.GetPaymentResult(paymentID) + resultChan, err := r.cfg.Payer.GetPaymentResult( + paymentID, errorDecryptor, + ) if err != nil { log.Errorf("Failed getting result for paymentID %d "+ "from switch: %v", paymentID, err) return [32]byte{}, true, err } - var result *htlcswitch.PaymentResult + var ( + result *htlcswitch.PaymentResult + ok bool + ) select { - case result = <-resultChan: + case result, ok = <-resultChan: + if !ok { + return [32]byte{}, true, htlcswitch.ErrSwitchExiting + } + case <-r.quit: return [32]byte{}, true, ErrRouterShuttingDown }