htlcswitch+router: move deobfuscator creation to GetPaymentResult call
In this commit we move handing the deobfuscator from the router to the switch from when the payment is initiated, to when the result is queried. We do this because only the router can recreate the deobfuscator after a restart, and we are preparing for being able to handle results across restarts. Since the deobfuscator cannot be nil anymore, we can also get rid of that special case.
This commit is contained in:
parent
f99d0c4c68
commit
cd02c22977
@ -1112,21 +1112,26 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) {
|
||||
// Send payment and expose err channel.
|
||||
err = n.aliceServer.htlcSwitch.SendHTLC(
|
||||
n.firstBobChannelLink.ShortChanID(), pid, htlc,
|
||||
newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to get send payment: %v", err)
|
||||
}
|
||||
|
||||
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid)
|
||||
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(
|
||||
pid, newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to get payment result: %v", err)
|
||||
}
|
||||
|
||||
var result *PaymentResult
|
||||
var ok bool
|
||||
select {
|
||||
|
||||
case result = <-resultChan:
|
||||
case result, ok = <-resultChan:
|
||||
if !ok {
|
||||
t.Fatalf("unexpected shutdown")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("no result arrive")
|
||||
}
|
||||
@ -3888,19 +3893,24 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
|
||||
// properly.
|
||||
err = n.aliceServer.htlcSwitch.SendHTLC(
|
||||
n.firstBobChannelLink.ShortChanID(), pid, htlc,
|
||||
newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send payment to carol: %v", err)
|
||||
}
|
||||
|
||||
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(pid)
|
||||
resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult(
|
||||
pid, newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to get payment result: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
case result, ok := <-resultChan:
|
||||
if !ok {
|
||||
t.Fatalf("unexpected shutdown")
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
t.Fatalf("payment failed: %v", result.Error)
|
||||
}
|
||||
@ -3912,7 +3922,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
|
||||
// as it's a duplicate request.
|
||||
err = n.aliceServer.htlcSwitch.SendHTLC(
|
||||
n.firstBobChannelLink.ShortChanID(), pid, htlc,
|
||||
newMockDeobfuscator(),
|
||||
)
|
||||
if err != ErrAlreadyPaid {
|
||||
t.Fatalf("ErrAlreadyPaid should have been received got: %v", err)
|
||||
|
@ -71,12 +71,7 @@ type pendingPayment struct {
|
||||
paymentHash lntypes.Hash
|
||||
amount lnwire.MilliSatoshi
|
||||
|
||||
resultChan chan *PaymentResult
|
||||
|
||||
// deobfuscator is a serializable entity which is used if we received
|
||||
// an error, it deobfuscates the onion failure blob, and extracts the
|
||||
// exact error from it.
|
||||
deobfuscator ErrorDecrypter
|
||||
resultChan chan *networkResult
|
||||
}
|
||||
|
||||
// plexPacket encapsulates switch packet and adds error channel to receive
|
||||
@ -347,9 +342,13 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
|
||||
|
||||
// GetPaymentResult returns the the result of the payment attempt with the
|
||||
// given paymentID. The method returns a channel where the payment result will
|
||||
// be sent when available, or an error is encountered. If the paymentID is
|
||||
// unknown, ErrPaymentIDNotFound will be returned.
|
||||
func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, error) {
|
||||
// be sent when available, or an error is encountered during forwarding. When a
|
||||
// result is received on the channel, the HTLC is guaranteed to no longer be in
|
||||
// flight. The switch shutting down is signaled by closing the channel. If the
|
||||
// paymentID is unknown, ErrPaymentIDNotFound will be returned.
|
||||
func (s *Switch) GetPaymentResult(paymentID uint64,
|
||||
deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) {
|
||||
|
||||
s.pendingMutex.Lock()
|
||||
payment, ok := s.pendingPayments[paymentID]
|
||||
s.pendingMutex.Unlock()
|
||||
@ -358,7 +357,42 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro
|
||||
return nil, ErrPaymentIDNotFound
|
||||
}
|
||||
|
||||
return payment.resultChan, nil
|
||||
resultChan := make(chan *PaymentResult, 1)
|
||||
|
||||
// Since the payment was known, we can start a goroutine that can
|
||||
// extract the result when it is available, and pass it on to the
|
||||
// caller.
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
|
||||
var n *networkResult
|
||||
select {
|
||||
case n = <-payment.resultChan:
|
||||
case <-s.quit:
|
||||
// We close the result channel to signal a shutdown. We
|
||||
// don't send any result in this case since the HTLC is
|
||||
// still in flight.
|
||||
close(resultChan)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract the result and pass it to the result channel.
|
||||
result, err := s.extractResult(
|
||||
deobfuscator, n, paymentID, payment.paymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
e := fmt.Errorf("Unable to extract result: %v", err)
|
||||
log.Error(e)
|
||||
resultChan <- &PaymentResult{
|
||||
Error: e,
|
||||
}
|
||||
return
|
||||
}
|
||||
resultChan <- result
|
||||
}()
|
||||
|
||||
return resultChan, nil
|
||||
}
|
||||
|
||||
// SendHTLC is used by other subsystems which aren't belong to htlc switch
|
||||
@ -366,7 +400,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64) (<-chan *PaymentResult, erro
|
||||
// for this HTLC, and MUST be used only once, otherwise the switch might reject
|
||||
// it.
|
||||
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
|
||||
htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) error {
|
||||
htlc *lnwire.UpdateAddHTLC) error {
|
||||
|
||||
// Before sending, double check that we don't already have 1) an
|
||||
// in-flight payment to this payment hash, or 2) a complete payment for
|
||||
@ -378,10 +412,9 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
|
||||
// Create payment and add to the map of payment in order later to be
|
||||
// able to retrieve it and return response to the user.
|
||||
payment := &pendingPayment{
|
||||
resultChan: make(chan *PaymentResult, 1),
|
||||
resultChan: make(chan *networkResult, 1),
|
||||
paymentHash: htlc.PaymentHash,
|
||||
amount: htlc.Amount,
|
||||
deobfuscator: deobfuscator,
|
||||
}
|
||||
|
||||
s.pendingMutex.Lock()
|
||||
@ -889,25 +922,16 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
|
||||
isResolution: pkt.isResolution,
|
||||
}
|
||||
|
||||
result, err := s.extractResult(
|
||||
payment, n, pkt.incomingHTLCID,
|
||||
pkt.circuit.PaymentHash,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to extract result: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Deliver the payment error and preimage to the application, if it is
|
||||
// waiting for a response.
|
||||
if payment != nil {
|
||||
payment.resultChan <- result
|
||||
payment.resultChan <- n
|
||||
}
|
||||
}
|
||||
|
||||
// extractResult uses the given deobfuscator to extract the payment result from
|
||||
// the given network message.
|
||||
func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
|
||||
func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
|
||||
paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) {
|
||||
|
||||
switch htlc := n.msg.(type) {
|
||||
@ -940,7 +964,7 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
|
||||
"%x: %v", paymentHash, err)
|
||||
}
|
||||
paymentErr := s.parseFailedPayment(
|
||||
payment, paymentID, payment.paymentHash, n.unencrypted,
|
||||
deobfuscator, paymentID, paymentHash, n.unencrypted,
|
||||
n.isResolution, htlc,
|
||||
)
|
||||
|
||||
@ -961,9 +985,9 @@ func (s *Switch) extractResult(payment *pendingPayment, n *networkResult,
|
||||
// reason attached.
|
||||
// 3) A failure from the remote party, which will need to be decrypted using
|
||||
// the payment deobfuscator.
|
||||
func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64,
|
||||
paymentHash lntypes.Hash, unencrypted, isResolution bool,
|
||||
htlc *lnwire.UpdateFailHTLC) *ForwardingError {
|
||||
func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter,
|
||||
paymentID uint64, paymentHash lntypes.Hash, unencrypted,
|
||||
isResolution bool, htlc *lnwire.UpdateFailHTLC) *ForwardingError {
|
||||
|
||||
var failure *ForwardingError
|
||||
|
||||
@ -1007,25 +1031,13 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, paymentID uint64,
|
||||
FailureMessage: lnwire.FailPermanentChannelFailure{},
|
||||
}
|
||||
|
||||
// If the provided payment is nil, we have discarded the error decryptor
|
||||
// due to a restart. We'll return a fixed error and signal a temporary
|
||||
// channel failure to the router.
|
||||
case payment == nil:
|
||||
userErr := fmt.Sprintf("error decryptor for payment " +
|
||||
"could not be located, likely due to restart")
|
||||
failure = &ForwardingError{
|
||||
ErrorSource: s.cfg.SelfKey,
|
||||
ExtraMsg: userErr,
|
||||
FailureMessage: lnwire.NewTemporaryChannelFailure(nil),
|
||||
}
|
||||
|
||||
// A regular multi-hop payment error that we'll need to
|
||||
// decrypt.
|
||||
default:
|
||||
var err error
|
||||
// We'll attempt to fully decrypt the onion encrypted
|
||||
// error. If we're unable to then we'll bail early.
|
||||
failure, err = payment.deobfuscator.DecryptError(htlc.Reason)
|
||||
failure, err = deobfuscator.DecryptError(htlc.Reason)
|
||||
if err != nil {
|
||||
userErr := fmt.Sprintf("unable to de-obfuscate "+
|
||||
"onion failure (hash=%v, pid=%d): %v",
|
||||
|
@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool,
|
||||
// We'll attempt to send out a new HTLC that has Alice as the first
|
||||
// outgoing link. This should fail as Alice isn't yet able to forward
|
||||
// any active HTLC's.
|
||||
err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg, nil)
|
||||
err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg)
|
||||
if err == nil {
|
||||
t.Fatalf("local forward should fail due to inactive link")
|
||||
}
|
||||
@ -1742,7 +1742,9 @@ func TestSwitchSendPayment(t *testing.T) {
|
||||
|
||||
// First check that the switch will correctly respond that this payment
|
||||
// ID is unknown.
|
||||
_, err = s.GetPaymentResult(paymentID)
|
||||
_, err = s.GetPaymentResult(
|
||||
paymentID, newMockDeobfuscator(),
|
||||
)
|
||||
if err != ErrPaymentIDNotFound {
|
||||
t.Fatalf("expected ErrPaymentIDNotFound, got %v", err)
|
||||
}
|
||||
@ -1752,19 +1754,25 @@ func TestSwitchSendPayment(t *testing.T) {
|
||||
go func() {
|
||||
err := s.SendHTLC(
|
||||
aliceChannelLink.ShortChanID(), paymentID, update,
|
||||
newMockDeobfuscator())
|
||||
)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
resultChan, err := s.GetPaymentResult(paymentID)
|
||||
resultChan, err := s.GetPaymentResult(
|
||||
paymentID, newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
result := <-resultChan
|
||||
result, ok := <-resultChan
|
||||
if !ok {
|
||||
errChan <- fmt.Errorf("shutting down")
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
errChan <- result.Error
|
||||
return
|
||||
|
@ -795,17 +795,23 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
// Send payment and expose err channel.
|
||||
return invoice, func() error {
|
||||
err := sender.htlcSwitch.SendHTLC(
|
||||
firstHop, pid, htlc, newMockDeobfuscator(),
|
||||
firstHop, pid, htlc,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultChan, err := sender.htlcSwitch.GetPaymentResult(pid)
|
||||
resultChan, err := sender.htlcSwitch.GetPaymentResult(
|
||||
pid, newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result := <-resultChan
|
||||
result, ok := <-resultChan
|
||||
if !ok {
|
||||
return fmt.Errorf("shutting down")
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
@ -1275,20 +1281,26 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
|
||||
// Send payment and expose err channel.
|
||||
go func() {
|
||||
err := sender.htlcSwitch.SendHTLC(
|
||||
firstHop, pid, htlc, newMockDeobfuscator(),
|
||||
firstHop, pid, htlc,
|
||||
)
|
||||
if err != nil {
|
||||
paymentErr <- err
|
||||
return
|
||||
}
|
||||
|
||||
resultChan, err := sender.htlcSwitch.GetPaymentResult(pid)
|
||||
resultChan, err := sender.htlcSwitch.GetPaymentResult(
|
||||
pid, newMockDeobfuscator(),
|
||||
)
|
||||
if err != nil {
|
||||
paymentErr <- err
|
||||
return
|
||||
}
|
||||
|
||||
result := <-resultChan
|
||||
result, ok := <-resultChan
|
||||
if !ok {
|
||||
paymentErr <- fmt.Errorf("shutting down")
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
paymentErr <- result.Error
|
||||
return
|
||||
|
@ -14,8 +14,7 @@ var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
|
||||
pid uint64,
|
||||
_ *lnwire.UpdateAddHTLC,
|
||||
_ htlcswitch.ErrorDecrypter) error {
|
||||
_ *lnwire.UpdateAddHTLC) error {
|
||||
|
||||
if m.onPayment == nil {
|
||||
return nil
|
||||
@ -44,8 +43,8 @@ func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64) (
|
||||
<-chan *htlcswitch.PaymentResult, error) {
|
||||
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
|
||||
_ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {
|
||||
|
||||
c := make(chan *htlcswitch.PaymentResult, 1)
|
||||
res, ok := m.results[paymentID]
|
||||
|
@ -134,15 +134,16 @@ type PaymentAttemptDispatcher interface {
|
||||
// payment was unsuccessful.
|
||||
SendHTLC(firstHop lnwire.ShortChannelID,
|
||||
paymentID uint64,
|
||||
htlcAdd *lnwire.UpdateAddHTLC,
|
||||
deobfuscator htlcswitch.ErrorDecrypter) error
|
||||
htlcAdd *lnwire.UpdateAddHTLC) error
|
||||
|
||||
// GetPaymentResult returns the the result of the payment attempt with
|
||||
// the given paymentID. The method returns a channel where the payment
|
||||
// result will be sent when available, or an error is encountered. If
|
||||
// the paymentID is unknown, htlcswitch.ErrPaymentIDNotFound will be
|
||||
// returned.
|
||||
GetPaymentResult(paymentID uint64) (
|
||||
// result will be sent when available, or an error is encountered
|
||||
// during forwarding. When a result is received on the channel, the
|
||||
// HTLC is guaranteed to no longer be in flight. The switch shutting
|
||||
// down is signaled by closing the channel. If the paymentID is
|
||||
// unknown, ErrPaymentIDNotFound will be returned.
|
||||
GetPaymentResult(paymentID uint64, deobfuscator htlcswitch.ErrorDecrypter) (
|
||||
<-chan *htlcswitch.PaymentResult, error)
|
||||
}
|
||||
|
||||
@ -1710,13 +1711,6 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
|
||||
route.Hops[0].ChannelID,
|
||||
)
|
||||
|
||||
// Using the created circuit, initialize the error decrypter so we can
|
||||
// parse+decode any failures incurred by this payment within the
|
||||
// switch.
|
||||
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
|
||||
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
|
||||
}
|
||||
|
||||
// We generate a new, unique payment ID that we will use for
|
||||
// this HTLC.
|
||||
paymentID, err := r.cfg.NextPaymentID()
|
||||
@ -1725,7 +1719,7 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
|
||||
}
|
||||
|
||||
err = r.cfg.Payer.SendHTLC(
|
||||
firstHop, paymentID, htlcAdd, errorDecryptor,
|
||||
firstHop, paymentID, htlcAdd,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Failed sending attempt %d for payment %x to "+
|
||||
@ -1740,18 +1734,34 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession,
|
||||
return [32]byte{}, finalOutcome, err
|
||||
}
|
||||
|
||||
// Using the created circuit, initialize the error decrypter so we can
|
||||
// parse+decode any failures incurred by this payment within the
|
||||
// switch.
|
||||
errorDecryptor := &htlcswitch.SphinxErrorDecrypter{
|
||||
OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit),
|
||||
}
|
||||
|
||||
// Now ask the switch to return the result of the payment when
|
||||
// available.
|
||||
resultChan, err := r.cfg.Payer.GetPaymentResult(paymentID)
|
||||
resultChan, err := r.cfg.Payer.GetPaymentResult(
|
||||
paymentID, errorDecryptor,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Failed getting result for paymentID %d "+
|
||||
"from switch: %v", paymentID, err)
|
||||
return [32]byte{}, true, err
|
||||
}
|
||||
|
||||
var result *htlcswitch.PaymentResult
|
||||
var (
|
||||
result *htlcswitch.PaymentResult
|
||||
ok bool
|
||||
)
|
||||
select {
|
||||
case result = <-resultChan:
|
||||
case result, ok = <-resultChan:
|
||||
if !ok {
|
||||
return [32]byte{}, true, htlcswitch.ErrSwitchExiting
|
||||
}
|
||||
|
||||
case <-r.quit:
|
||||
return [32]byte{}, true, ErrRouterShuttingDown
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user