htlcswitch+router: move deobfuscator creation to GetPaymentResult call

In this commit we move handing the deobfuscator from the router to the
switch from when the payment is initiated, to when the result is
queried.

We do this because only the router can recreate the deobfuscator after a
restart, and we are preparing for being able to handle results across
restarts.

Since the deobfuscator cannot be nil anymore, we can also get rid of
that special case.
This commit is contained in:
Johan T. Halseth 2019-05-16 15:27:29 +02:00
parent f99d0c4c68
commit cd02c22977
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
6 changed files with 132 additions and 82 deletions

@ -1112,21 +1112,26 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) {
// Send payment and expose err channel. // Send payment and expose err channel.
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc, n.firstBobChannelLink.ShortChanID(), pid, htlc,
newMockDeobfuscator(),
) )
if err != nil { if err != nil {
t.Fatalf("unable to get send payment: %v", err) t.Fatalf("unable to get send payment: %v", err)
} }
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid) resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil { if err != nil {
t.Fatalf("unable to get payment result: %v", err) t.Fatalf("unable to get payment result: %v", err)
} }
var result *PaymentResult var result *PaymentResult
var ok bool
select { select {
case result = <-resultChan: case result, ok = <-resultChan:
if !ok {
t.Fatalf("unexpected shutdown")
}
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("no result arrive") t.Fatalf("no result arrive")
} }
@ -3888,19 +3893,24 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
// properly. // properly.
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc, n.firstBobChannelLink.ShortChanID(), pid, htlc,
newMockDeobfuscator(),
) )
if err != nil { if err != nil {
t.Fatalf("unable to send payment to carol: %v", err) t.Fatalf("unable to send payment to carol: %v", err)
} }
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid) resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil { if err != nil {
t.Fatalf("unable to get payment result: %v", err) t.Fatalf("unable to get payment result: %v", err)
} }
select { select {
case result := <-resultChan: case result, ok := <-resultChan:
if !ok {
t.Fatalf("unexpected shutdown")
}
if result.Error != nil { if result.Error != nil {
t.Fatalf("payment failed: %v", result.Error) t.Fatalf("payment failed: %v", result.Error)
} }
@ -3912,7 +3922,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
// as it's a duplicate request. // as it's a duplicate request.
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc, n.firstBobChannelLink.ShortChanID(), pid, htlc,
newMockDeobfuscator(),
) )
if err != ErrAlreadyPaid { if err != ErrAlreadyPaid {
t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) t.Fatalf("ErrAlreadyPaid should have been received got: %v", err)

@ -71,12 +71,7 @@ type pendingPayment struct {
paymentHash lntypes.Hash paymentHash lntypes.Hash
amount lnwire.MilliSatoshi amount lnwire.MilliSatoshi
resultChan chan *PaymentResult resultChan chan *networkResult
// deobfuscator is a serializable entity which is used if we received
// an error, it deobfuscates the onion failure blob, and extracts the
// exact error from it.
deobfuscator ErrorDecrypter
} }
// plexPacket encapsulates switch packet and adds error channel to receive // plexPacket encapsulates switch packet and adds error channel to receive
@ -347,9 +342,13 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
// GetPaymentResult returns the the result of the payment attempt with the // GetPaymentResult returns the the result of the payment attempt with the
// given paymentID. The method returns a channel where the payment result will // given paymentID. The method returns a channel where the payment result will
// be sent when available, or an error is encountered. If the paymentID is // be sent when available, or an error is encountered during forwarding. When a
// unknown, ErrPaymentIDNotFound will be returned. // result is received on the channel, the HTLC is guaranteed to no longer be in
func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, error) { // flight. The switch shutting down is signaled by closing the channel. If the
// paymentID is unknown, ErrPaymentIDNotFound will be returned.
func (s *Switch) GetPaymentResult(paymentID uint64,
deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) {
s.pendingMutex.Lock() s.pendingMutex.Lock()
payment, ok := s.pendingPayments[paymentID] payment, ok := s.pendingPayments[paymentID]
s.pendingMutex.Unlock() s.pendingMutex.Unlock()
@ -358,7 +357,42 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro
return nil, ErrPaymentIDNotFound return nil, ErrPaymentIDNotFound
} }
return payment.resultChan, nil resultChan := make(chan *PaymentResult, 1)
// Since the payment was known, we can start a goroutine that can
// extract the result when it is available, and pass it on to the
// caller.
s.wg.Add(1)
go func() {
defer s.wg.Done()
var n *networkResult
select {
case n = <-payment.resultChan:
case <-s.quit:
// We close the result channel to signal a shutdown. We
// don't send any result in this case since the HTLC is
// still in flight.
close(resultChan)
return
}
// Extract the result and pass it to the result channel.
result, err := s.extractResult(
deobfuscator, n, paymentID, payment.paymentHash,
)
if err != nil {
e := fmt.Errorf("Unable to extract result: %v", err)
log.Error(e)
resultChan <- &PaymentResult{
Error: e,
}
return
}
resultChan <- result
}()
return resultChan, nil
} }
// SendHTLC is used by other subsystems which aren't belong to htlc switch // SendHTLC is used by other subsystems which aren't belong to htlc switch
@ -366,7 +400,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro
// for this HTLC, and MUST be used only once, otherwise the switch might reject // for this HTLC, and MUST be used only once, otherwise the switch might reject
// it. // it.
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) error { htlc *lnwire.UpdateAddHTLC) error {
// Before sending, double check that we don't already have 1) an // Before sending, double check that we don't already have 1) an
// in-flight payment to this payment hash, or 2) a complete payment for // in-flight payment to this payment hash, or 2) a complete payment for
@ -378,10 +412,9 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
// Create payment and add to the map of payment in order later to be // Create payment and add to the map of payment in order later to be
// able to retrieve it and return response to the user. // able to retrieve it and return response to the user.
payment := &pendingPayment{ payment := &pendingPayment{
resultChan: make(chan *PaymentResult, 1), resultChan: make(chan *networkResult, 1),
paymentHash: htlc.PaymentHash, paymentHash: htlc.PaymentHash,
amount: htlc.Amount, amount: htlc.Amount,
deobfuscator: deobfuscator,
} }
s.pendingMutex.Lock() s.pendingMutex.Lock()
@ -889,25 +922,16 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
isResolution: pkt.isResolution, isResolution: pkt.isResolution,
} }
result, err := s.extractResult(
payment, n, pkt.incomingHTLCID,
pkt.circuit.PaymentHash,
)
if err != nil {
log.Errorf("Unable to extract result: %v", err)
return
}
// Deliver the payment error and preimage to the application, if it is // Deliver the payment error and preimage to the application, if it is
// waiting for a response. // waiting for a response.
if payment != nil { if payment != nil {
payment.resultChan <- result payment.resultChan <- n
} }
} }
// extractResult uses the given deobfuscator to extract the payment result from // extractResult uses the given deobfuscator to extract the payment result from
// the given network message. // the given network message.
func (s *Switch) extractResult(payment *pendingPayment, n *networkResult, func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) { paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) {
switch htlc := n.msg.(type) { switch htlc := n.msg.(type) {
@ -940,7 +964,7 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
"%x: %v", paymentHash, err) "%x: %v", paymentHash, err)
} }
paymentErr := s.parseFailedPayment( paymentErr := s.parseFailedPayment(
payment, paymentID, payment.paymentHash, n.unencrypted, deobfuscator, paymentID, paymentHash, n.unencrypted,
n.isResolution, htlc, n.isResolution, htlc,
) )
@ -961,9 +985,9 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
// reason attached. // reason attached.
// 3) A failure from the remote party, which will need to be decrypted using // 3) A failure from the remote party, which will need to be decrypted using
// the payment deobfuscator. // the payment deobfuscator.
func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64, func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter,
paymentHash lntypes.Hash, unencrypted, isResolution bool, paymentID uint64, paymentHash lntypes.Hash, unencrypted,
htlc *lnwire.UpdateFailHTLC) *ForwardingError { isResolution bool, htlc *lnwire.UpdateFailHTLC) *ForwardingError {
var failure *ForwardingError var failure *ForwardingError
@ -1007,25 +1031,13 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64,
FailureMessage: lnwire.FailPermanentChannelFailure{}, FailureMessage: lnwire.FailPermanentChannelFailure{},
} }
// If the provided payment is nil, we have discarded the error decryptor
// due to a restart. We'll return a fixed error and signal a temporary
// channel failure to the router.
case payment == nil:
userErr := fmt.Sprintf("error decryptor for payment " +
"could not be located, likely due to restart")
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.NewTemporaryChannelFailure(nil),
}
// A regular multi-hop payment error that we'll need to // A regular multi-hop payment error that we'll need to
// decrypt. // decrypt.
default: default:
var err error var err error
// We'll attempt to fully decrypt the onion encrypted // We'll attempt to fully decrypt the onion encrypted
// error. If we're unable to then we'll bail early. // error. If we're unable to then we'll bail early.
failure, err = payment.deobfuscator.DecryptError(htlc.Reason) failure, err = deobfuscator.DecryptError(htlc.Reason)
if err != nil { if err != nil {
userErr := fmt.Sprintf("unable to de-obfuscate "+ userErr := fmt.Sprintf("unable to de-obfuscate "+
"onion failure (hash=%v, pid=%d): %v", "onion failure (hash=%v, pid=%d): %v",

@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool,
// We'll attempt to send out a new HTLC that has Alice as the first // We'll attempt to send out a new HTLC that has Alice as the first
// outgoing link. This should fail as Alice isn't yet able to forward // outgoing link. This should fail as Alice isn't yet able to forward
// any active HTLC's. // any active HTLC's.
err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg, nil) err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg)
if err == nil { if err == nil {
t.Fatalf("local forward should fail due to inactive link") t.Fatalf("local forward should fail due to inactive link")
} }
@ -1742,7 +1742,9 @@ func TestSwitchSendPayment(t *testing.T) {
// First check that the switch will correctly respond that this payment // First check that the switch will correctly respond that this payment
// ID is unknown. // ID is unknown.
_, err = s.GetPaymentResult(paymentID) _, err = s.GetPaymentResult(
paymentID, newMockDeobfuscator(),
)
if err != ErrPaymentIDNotFound { if err != ErrPaymentIDNotFound {
t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
} }
@ -1752,19 +1754,25 @@ func TestSwitchSendPayment(t *testing.T) {
go func() { go func() {
err := s.SendHTLC( err := s.SendHTLC(
aliceChannelLink.ShortChanID(), paymentID, update, aliceChannelLink.ShortChanID(), paymentID, update,
newMockDeobfuscator()) )
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
} }
resultChan, err := s.GetPaymentResult(paymentID) resultChan, err := s.GetPaymentResult(
paymentID, newMockDeobfuscator(),
)
if err != nil { if err != nil {
errChan <- err errChan <- err
return return
} }
result := <-resultChan result, ok := <-resultChan
if !ok {
errChan <- fmt.Errorf("shutting down")
}
if result.Error != nil { if result.Error != nil {
errChan <- result.Error errChan <- result.Error
return return

@ -795,17 +795,23 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
// Send payment and expose err channel. // Send payment and expose err channel.
return invoice, func() error { return invoice, func() error {
err := sender.htlcSwitch.SendHTLC( err := sender.htlcSwitch.SendHTLC(
firstHop, pid, htlc, newMockDeobfuscator(), firstHop, pid, htlc,
) )
if err != nil { if err != nil {
return err return err
} }
resultChan, err := sender.htlcSwitch.GetPaymentResult(pid) resultChan, err := sender.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil { if err != nil {
return err return err
} }
result := <-resultChan result, ok := <-resultChan
if !ok {
return fmt.Errorf("shutting down")
}
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
@ -1275,20 +1281,26 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
// Send payment and expose err channel. // Send payment and expose err channel.
go func() { go func() {
err := sender.htlcSwitch.SendHTLC( err := sender.htlcSwitch.SendHTLC(
firstHop, pid, htlc, newMockDeobfuscator(), firstHop, pid, htlc,
) )
if err != nil { if err != nil {
paymentErr <- err paymentErr <- err
return return
} }
resultChan, err := sender.htlcSwitch.GetPaymentResult(pid) resultChan, err := sender.htlcSwitch.GetPaymentResult(
pid, newMockDeobfuscator(),
)
if err != nil { if err != nil {
paymentErr <- err paymentErr <- err
return return
} }
result := <-resultChan result, ok := <-resultChan
if !ok {
paymentErr <- fmt.Errorf("shutting down")
}
if result.Error != nil { if result.Error != nil {
paymentErr <- result.Error paymentErr <- result.Error
return return

@ -14,8 +14,7 @@ var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
pid uint64, pid uint64,
_ *lnwire.UpdateAddHTLC, _ *lnwire.UpdateAddHTLC) error {
_ htlcswitch.ErrorDecrypter) error {
if m.onPayment == nil { if m.onPayment == nil {
return nil return nil
@ -44,8 +43,8 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
return nil return nil
} }
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64) ( func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
<-chan *htlcswitch.PaymentResult, error) { _ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {
c := make(chan *htlcswitch.PaymentResult, 1) c := make(chan *htlcswitch.PaymentResult, 1)
res, ok := m.results[paymentID] res, ok := m.results[paymentID]

@ -134,15 +134,16 @@ type PaymentAttemptDispatcher interface {
// payment was unsuccessful. // payment was unsuccessful.
SendHTLC(firstHop lnwire.ShortChannelID, SendHTLC(firstHop lnwire.ShortChannelID,
paymentID uint64, paymentID uint64,
htlcAdd *lnwire.UpdateAddHTLC, htlcAdd *lnwire.UpdateAddHTLC) error
deobfuscator htlcswitch.ErrorDecrypter) error
// GetPaymentResult returns the the result of the payment attempt with // GetPaymentResult returns the the result of the payment attempt with
// the given paymentID. The method returns a channel where the payment // the given paymentID. The method returns a channel where the payment
// result will be sent when available, or an error is encountered. If // result will be sent when available, or an error is encountered
// the paymentID is unknown, htlcswitch.ErrPaymentIDNotFound will be // during forwarding. When a result is received on the channel, the
// returned. // HTLC is guaranteed to no longer be in flight. The switch shutting
GetPaymentResult(paymentID uint64) ( // down is signaled by closing the channel. If the paymentID is
// unknown, ErrPaymentIDNotFound will be returned.
GetPaymentResult(paymentID uint64, deobfuscator htlcswitch.ErrorDecrypter) (
<-chan *htlcswitch.PaymentResult, error) <-chan *htlcswitch.PaymentResult, error)
} }
@ -1710,13 +1711,6 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
route.Hops[0].ChannelID, route.Hops[0].ChannelID,
) )
// Using the created circuit, initialize the error decrypter so we can
// parse+decode any failures incurred by this payment within the
// switch.
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
}
// We generate a new, unique payment ID that we will use for // We generate a new, unique payment ID that we will use for
// this HTLC. // this HTLC.
paymentID, err := r.cfg.NextPaymentID() paymentID, err := r.cfg.NextPaymentID()
@ -1725,7 +1719,7 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
} }
err = r.cfg.Payer.SendHTLC( err = r.cfg.Payer.SendHTLC(
firstHop, paymentID, htlcAdd, errorDecryptor, firstHop, paymentID, htlcAdd,
) )
if err != nil { if err != nil {
log.Errorf("Failed sending attempt %d for payment %x to "+ log.Errorf("Failed sending attempt %d for payment %x to "+
@ -1740,18 +1734,34 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
return [32]byte{}, finalOutcome, err return [32]byte{}, finalOutcome, err
} }
// Using the created circuit, initialize the error decrypter so we can
// parse+decode any failures incurred by this payment within the
// switch.
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
}
// Now ask the switch to return the result of the payment when // Now ask the switch to return the result of the payment when
// available. // available.
resultChan, err := r.cfg.Payer.GetPaymentResult(paymentID) resultChan, err := r.cfg.Payer.GetPaymentResult(
paymentID, errorDecryptor,
)
if err != nil { if err != nil {
log.Errorf("Failed getting result for paymentID %d "+ log.Errorf("Failed getting result for paymentID %d "+
"from switch: %v", paymentID, err) "from switch: %v", paymentID, err)
return [32]byte{}, true, err return [32]byte{}, true, err
} }
var result *htlcswitch.PaymentResult var (
result *htlcswitch.PaymentResult
ok bool
)
select { select {
case result = <-resultChan: case result, ok = <-resultChan:
if !ok {
return [32]byte{}, true, htlcswitch.ErrSwitchExiting
}
case <-r.quit: case <-r.quit:
return [32]byte{}, true, ErrRouterShuttingDown return [32]byte{}, true, ErrRouterShuttingDown
} }