diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index c1fd4b94..fb3cfc56 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -54,6 +54,8 @@ func newConcurrentTester(t *testing.T) *concurrentTester { } func (c *concurrentTester) Fatalf(format string, args ...interface{}) { + c.T.Helper() + c.mtx.Lock() defer c.mtx.Unlock() @@ -1100,20 +1102,43 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) { // Generate payment invoice and htlc, but don't add this invoice to the // receiver registry. This should trigger an unknown payment hash // failure. - _, htlc, err := generatePayment(amount, htlcAmt, totalTimelock, - blob) + _, htlc, pid, err := generatePayment( + amount, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatal(err) } // Send payment and expose err channel. - _, err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), htlc, - newMockDeobfuscator(), + err = n.aliceServer.htlcSwitch.SendHTLC( + n.firstBobChannelLink.ShortChanID(), pid, htlc, ) - if !strings.Contains(err.Error(), lnwire.CodeUnknownPaymentHash.String()) { - t.Fatalf("expected %v got %v", err, - lnwire.CodeUnknownPaymentHash) + if err != nil { + t.Fatalf("unable to get send payment: %v", err) + } + + resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) + if err != nil { + t.Fatalf("unable to get payment result: %v", err) + } + + var result *PaymentResult + var ok bool + select { + + case result, ok = <-resultChan: + if !ok { + t.Fatalf("unexpected shutdown") + } + case <-time.After(5 * time.Second): + t.Fatalf("no result arrive") + } + + fErr := result.Error + if !strings.Contains(fErr.Error(), lnwire.CodeUnknownPaymentHash.String()) { + t.Fatalf("expected %v got %v", lnwire.CodeUnknownPaymentHash, fErr) } // Wait for Alice to receive the revocation. @@ -1909,7 +1934,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { // a switch initiated payment. The resulting bandwidth should // now be decremented to reflect the new HTLC. htlcAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + invoice, htlc, _, err := generatePayment( + htlcAmt, htlcAmt, 5, mockBlob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -1989,7 +2016,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { // Next, we'll add another HTLC initiated by the switch (of the same // amount as the prior one). - invoice, htlc, err = generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + invoice, htlc, _, err = generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2075,8 +2102,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if err != nil { t.Fatalf("unable to gen route: %v", err) } - invoice, htlc, err = generatePayment(htlcAmt, htlcAmt, - totalTimelock, blob) + invoice, htlc, _, err = generatePayment( + htlcAmt, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2183,7 +2211,9 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) { if err != nil { t.Fatalf("unable to gen route: %v", err) } - invoice, htlc, err = generatePayment(htlcAmt, htlcAmt, totalTimelock, blob) + invoice, htlc, _, err = generatePayment( + htlcAmt, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2314,7 +2344,9 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) { var htlcID uint64 addLinkHTLC := func(id uint64, amt lnwire.MilliSatoshi) [32]byte { - invoice, htlc, err := generatePayment(amt, amt, 5, mockBlob) + invoice, htlc, _, err := generatePayment( + amt, amt, 5, mockBlob, + ) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2580,7 +2612,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) { // message for the test. var mockBlob [lnwire.OnionPacketSize]byte htlcAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - _, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + _, htlc, _, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -2860,7 +2892,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) { // message for the test. var mockBlob [lnwire.OnionPacketSize]byte htlcAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) - _, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + _, htlc, _, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -3113,7 +3145,7 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) { // a switch initiated payment. The resulting bandwidth should // now be decremented to reflect the new HTLC. htlcAmt := lnwire.NewMSatFromSatoshis(3 * btcutil.SatoshiPerBitcoin) - invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) + invoice, htlc, _, err := generatePayment(htlcAmt, htlcAmt, 5, mockBlob) if err != nil { t.Fatalf("unable to create payment: %v", err) } @@ -3844,8 +3876,9 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { if err != nil { t.Fatal(err) } - invoice, htlc, err := generatePayment(amount, htlcAmt, totalTimelock, - blob) + invoice, htlc, pid, err := generatePayment( + amount, htlcAmt, totalTimelock, blob, + ) if err != nil { t.Fatal(err) } @@ -3858,19 +3891,37 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { // With the invoice now added to Carol's registry, we'll send the // payment. It should succeed w/o any issues as it has been crafted // properly. - _, err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), htlc, - newMockDeobfuscator(), + err = n.aliceServer.htlcSwitch.SendHTLC( + n.firstBobChannelLink.ShortChanID(), pid, htlc, ) if err != nil { t.Fatalf("unable to send payment to carol: %v", err) } + resultChan, err := n.aliceServer.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) + if err != nil { + t.Fatalf("unable to get payment result: %v", err) + } + + select { + case result, ok := <-resultChan: + if !ok { + t.Fatalf("unexpected shutdown") + } + + if result.Error != nil { + t.Fatalf("payment failed: %v", result.Error) + } + case <-time.After(5 * time.Second): + t.Fatalf("payment result did not arrive") + } + // Now, if we attempt to send the payment *again* it should be rejected // as it's a duplicate request. - _, err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), htlc, - newMockDeobfuscator(), + err = n.aliceServer.htlcSwitch.SendHTLC( + n.firstBobChannelLink.ShortChanID(), pid, htlc, ) if err != ErrAlreadyPaid { t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) @@ -4255,7 +4306,7 @@ func generateHtlcAndInvoice(t *testing.T, t.Fatalf("unable to generate route: %v", err) } - invoice, htlc, err := generatePayment( + invoice, htlc, _, err := generatePayment( htlcAmt, htlcAmt, uint32(htlcExpiry), blob, ) if err != nil { diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go new file mode 100644 index 00000000..5cfce845 --- /dev/null +++ b/htlcswitch/payment_result.go @@ -0,0 +1,48 @@ +package htlcswitch + +import ( + "errors" + + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrPaymentIDNotFound is an error returned if the given paymentID is + // not found. + ErrPaymentIDNotFound = errors.New("paymentID not found") + + // ErrPaymentIDAlreadyExists is returned if we try to write a pending + // payment whose paymentID already exists. + ErrPaymentIDAlreadyExists = errors.New("paymentID already exists") +) + +// PaymentResult wraps a decoded result received from the network after a +// payment attempt was made. This is what is eventually handed to the router +// for processing. +type PaymentResult struct { + // Preimage is set by the switch in case a sent HTLC was settled. + Preimage [32]byte + + // Error is non-nil in case a HTLC send failed, and the HTLC is now + // irrevocably cancelled. If the payment failed during forwarding, this + // error will be a *ForwardingError. + Error error +} + +// networkResult is the raw result received from the network after a payment +// attempt has been made. Since the switch doesn't always have the necessary +// data to decode the raw message, we store it together with some meta data, +// and decode it when the router query for the final result. +type networkResult struct { + // msg is the received result. This should be of type UpdateFulfillHTLC + // or UpdateFailHTLC. + msg lnwire.Message + + // unencrypted indicates whether the failure encoded in the message is + // unencrypted, and hence doesn't need to be decrypted. + unencrypted bool + + // isResolution indicates whether this is a resolution message, in + // which the failure reason might not be included. + isResolution bool +} diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index f0c1c425..c850783e 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/ticker" @@ -67,16 +68,10 @@ var ( // updates to be received whether the payment has been rejected or proceed // successfully. type pendingPayment struct { - paymentHash lnwallet.PaymentHash + paymentHash lntypes.Hash amount lnwire.MilliSatoshi - preimage chan [sha256.Size]byte - err chan error - - // deobfuscator is a serializable entity which is used if we received - // an error, it deobfuscates the onion failure blob, and extracts the - // exact error from it. - deobfuscator ErrorDecrypter + resultChan chan *networkResult } // plexPacket encapsulates switch packet and adds error channel to receive @@ -213,8 +208,6 @@ type Switch struct { pendingPayments map[uint64]*pendingPayment pendingMutex sync.RWMutex - paymentSequencer Sequencer - // control provides verification of sending htlc mesages control ControlTower @@ -293,16 +286,10 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { return nil, err } - sequencer, err := NewPersistentSequencer(cfg.DB) - if err != nil { - return nil, err - } - return &Switch{ bestHeight: currentHeight, cfg: &cfg, circuits: circuitMap, - paymentSequencer: sequencer, control: NewPaymentControl(false, cfg.DB), linkIndex: make(map[lnwire.ChannelID]ChannelLink), mailOrchestrator: newMailOrchestrator(), @@ -353,35 +340,90 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro return nil } +// GetPaymentResult returns the the result of the payment attempt with the +// given paymentID. The method returns a channel where the payment result will +// be sent when available, or an error is encountered during forwarding. When a +// result is received on the channel, the HTLC is guaranteed to no longer be in +// flight. The switch shutting down is signaled by closing the channel. If the +// paymentID is unknown, ErrPaymentIDNotFound will be returned. +func (s *Switch) GetPaymentResult(paymentID uint64, + deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) { + + s.pendingMutex.Lock() + payment, ok := s.pendingPayments[paymentID] + s.pendingMutex.Unlock() + + if !ok { + return nil, ErrPaymentIDNotFound + } + + resultChan := make(chan *PaymentResult, 1) + + // Since the payment was known, we can start a goroutine that can + // extract the result when it is available, and pass it on to the + // caller. + s.wg.Add(1) + go func() { + defer s.wg.Done() + + var n *networkResult + select { + case n = <-payment.resultChan: + case <-s.quit: + // We close the result channel to signal a shutdown. We + // don't send any result in this case since the HTLC is + // still in flight. + close(resultChan) + return + } + + // Extract the result and pass it to the result channel. + result, err := s.extractResult( + deobfuscator, n, paymentID, payment.paymentHash, + ) + if err != nil { + e := fmt.Errorf("Unable to extract result: %v", err) + log.Error(e) + resultChan <- &PaymentResult{ + Error: e, + } + return + } + resultChan <- result + }() + + return resultChan, nil +} + // SendHTLC is used by other subsystems which aren't belong to htlc switch -// package in order to send the htlc update. -func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, - htlc *lnwire.UpdateAddHTLC, - deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) { +// package in order to send the htlc update. The paymentID used MUST be unique +// for this HTLC, and MUST be used only once, otherwise the switch might reject +// it. +func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, + htlc *lnwire.UpdateAddHTLC) error { // Before sending, double check that we don't already have 1) an // in-flight payment to this payment hash, or 2) a complete payment for // the same hash. if err := s.control.ClearForTakeoff(htlc); err != nil { - return zeroPreimage, err + return err } // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ - err: make(chan error, 1), - preimage: make(chan [sha256.Size]byte, 1), - paymentHash: htlc.PaymentHash, - amount: htlc.Amount, - deobfuscator: deobfuscator, - } - - paymentID, err := s.paymentSequencer.NextID() - if err != nil { - return zeroPreimage, err + resultChan: make(chan *networkResult, 1), + paymentHash: htlc.PaymentHash, + amount: htlc.Amount, } s.pendingMutex.Lock() + if _, ok := s.pendingPayments[paymentID]; ok { + s.pendingMutex.Unlock() + + return ErrPaymentIDAlreadyExists + } + s.pendingPayments[paymentID] = payment s.pendingMutex.Unlock() @@ -398,31 +440,13 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, if err := s.forward(packet); err != nil { s.removePendingPayment(paymentID) if err := s.control.Fail(htlc.PaymentHash); err != nil { - return zeroPreimage, err + return err } - return zeroPreimage, err + return err } - // Returns channels so that other subsystem might wait/skip the - // waiting of handling of payment. - var preimage [sha256.Size]byte - - select { - case e := <-payment.err: - err = e - case <-s.quit: - return zeroPreimage, ErrSwitchExiting - } - - select { - case p := <-payment.preimage: - preimage = p - case <-s.quit: - return zeroPreimage, ErrSwitchExiting - } - - return preimage, err + return nil } // UpdateForwardingPolicies sends a message to the switch to update the @@ -889,12 +913,28 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) { // has been restarted since sending the payment. payment := s.findPayment(pkt.incomingHTLCID) - var ( - preimage [32]byte - paymentErr error - ) + // The error reason will be unencypted in case this a local + // failure or a converted error. + unencrypted := pkt.localFailure || pkt.convertedError + n := &networkResult{ + msg: pkt.htlc, + unencrypted: unencrypted, + isResolution: pkt.isResolution, + } - switch htlc := pkt.htlc.(type) { + // Deliver the payment error and preimage to the application, if it is + // waiting for a response. + if payment != nil { + payment.resultChan <- n + } +} + +// extractResult uses the given deobfuscator to extract the payment result from +// the given network message. +func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult, + paymentID uint64, paymentHash lntypes.Hash) (*PaymentResult, error) { + + switch htlc := n.msg.(type) { // We've received a settle update which means we can finalize the user // payment and return successful response. @@ -902,52 +942,52 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) { // Persistently mark that a payment to this payment hash // succeeded. This will prevent us from ever making another // payment to this hash. - err := s.control.Success(pkt.circuit.PaymentHash) + err := s.control.Success(paymentHash) if err != nil && err != ErrPaymentAlreadyCompleted { - log.Warnf("Unable to mark completed payment %x: %v", - pkt.circuit.PaymentHash, err) - return + return nil, fmt.Errorf("Unable to mark completed "+ + "payment %x: %v", paymentHash, err) } - preimage = htlc.PaymentPreimage + return &PaymentResult{ + Preimage: htlc.PaymentPreimage, + }, nil - // We've received a fail update which means we can finalize the user - // payment and return fail response. + // We've received a fail update which means we can finalize the + // user payment and return fail response. case *lnwire.UpdateFailHTLC: - // Persistently mark that a payment to this payment hash failed. - // This will permit us to make another attempt at a successful - // payment. - err := s.control.Fail(pkt.circuit.PaymentHash) + // Persistently mark that a payment to this payment hash + // failed. This will permit us to make another attempt at a + // successful payment. + err := s.control.Fail(paymentHash) if err != nil && err != ErrPaymentAlreadyCompleted { - log.Warnf("Unable to ground payment %x: %v", - pkt.circuit.PaymentHash, err) - return + return nil, fmt.Errorf("Unable to ground payment "+ + "%x: %v", paymentHash, err) } + paymentErr := s.parseFailedPayment( + deobfuscator, paymentID, paymentHash, n.unencrypted, + n.isResolution, htlc, + ) - paymentErr = s.parseFailedPayment(payment, pkt, htlc) + return &PaymentResult{ + Error: paymentErr, + }, nil default: - log.Warnf("Received unknown response type: %T", pkt.htlc) - return - } - - // Deliver the payment error and preimage to the application, if it is - // waiting for a response. - if payment != nil { - payment.err <- paymentErr - payment.preimage <- preimage - s.removePendingPayment(pkt.incomingHTLCID) + return nil, fmt.Errorf("Received unknown response type: %T", + htlc) } } // parseFailedPayment determines the appropriate failure message to return to // a user initiated payment. The three cases handled are: -// 1) A local failure, which should already plaintext. -// 2) A resolution from the chain arbitrator, -// 3) A failure from the remote party, which will need to be decrypted using the -// payment deobfuscator. -func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, - htlc *lnwire.UpdateFailHTLC) *ForwardingError { +// 1) An unencrypted failure, which should already plaintext. +// 2) A resolution from the chain arbitrator, which possibly has no failure +// reason attached. +// 3) A failure from the remote party, which will need to be decrypted using +// the payment deobfuscator. +func (s *Switch) parseFailedPayment(deobfuscator ErrorDecrypter, + paymentID uint64, paymentHash lntypes.Hash, unencrypted, + isResolution bool, htlc *lnwire.UpdateFailHTLC) *ForwardingError { var failure *ForwardingError @@ -956,14 +996,14 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, // The payment never cleared the link, so we don't need to // decrypt the error, simply decode it them report back to the // user. - case pkt.localFailure || pkt.convertedError: + case unencrypted: var userErr string r := bytes.NewReader(htlc.Reason) failureMsg, err := lnwire.DecodeFailure(r, 0) if err != nil { - userErr = fmt.Sprintf("unable to decode onion failure, "+ - "htlc with hash(%x): %v", - pkt.circuit.PaymentHash[:], err) + userErr = fmt.Sprintf("unable to decode onion "+ + "failure (hash=%v, pid=%d): %v", + paymentHash, paymentID, err) log.Error(userErr) // As this didn't even clear the link, we don't need to @@ -981,38 +1021,27 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, // the first hop. In this case, we'll report a permanent // channel failure as this means us, or the remote party had to // go on chain. - case pkt.isResolution && htlc.Reason == nil: - userErr := fmt.Sprintf("payment was resolved " + - "on-chain, then cancelled back") + case isResolution && htlc.Reason == nil: + userErr := fmt.Sprintf("payment was resolved "+ + "on-chain, then cancelled back (hash=%v, pid=%d)", + paymentHash, paymentID) failure = &ForwardingError{ ErrorSource: s.cfg.SelfKey, ExtraMsg: userErr, FailureMessage: lnwire.FailPermanentChannelFailure{}, } - // If the provided payment is nil, we have discarded the error decryptor - // due to a restart. We'll return a fixed error and signal a temporary - // channel failure to the router. - case payment == nil: - userErr := fmt.Sprintf("error decryptor for payment " + - "could not be located, likely due to restart") - failure = &ForwardingError{ - ErrorSource: s.cfg.SelfKey, - ExtraMsg: userErr, - FailureMessage: lnwire.NewTemporaryChannelFailure(nil), - } - // A regular multi-hop payment error that we'll need to // decrypt. default: var err error // We'll attempt to fully decrypt the onion encrypted // error. If we're unable to then we'll bail early. - failure, err = payment.deobfuscator.DecryptError(htlc.Reason) + failure, err = deobfuscator.DecryptError(htlc.Reason) if err != nil { - userErr := fmt.Sprintf("unable to de-obfuscate onion "+ - "failure, htlc with hash(%x): %v", - pkt.circuit.PaymentHash[:], err) + userErr := fmt.Sprintf("unable to de-obfuscate "+ + "onion failure (hash=%v, pid=%d): %v", + paymentHash, paymentID, err) log.Error(userErr) failure = &ForwardingError{ ErrorSource: s.cfg.SelfKey, @@ -2206,15 +2235,6 @@ func (s *Switch) CircuitModifier() CircuitModifier { return s.circuits } -// numPendingPayments is helper function which returns the overall number of -// pending user payments. -func (s *Switch) numPendingPayments() int { - s.pendingMutex.RLock() - defer s.pendingMutex.RUnlock() - - return len(s.pendingPayments) -} - // commitCircuits persistently adds a circuit to the switch's circuit map. func (s *Switch) commitCircuits(circuits ...*PaymentCircuit) ( *CircuitFwdActions, error) { diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 75732a2d..cd798301 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -1417,7 +1417,7 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool, // We'll attempt to send out a new HTLC that has Alice as the first // outgoing link. This should fail as Alice isn't yet able to forward // any active HTLC's. - _, err = s.SendHTLC(aliceChannelLink.ShortChanID(), addMsg, nil) + err = s.SendHTLC(aliceChannelLink.ShortChanID(), 0, addMsg) if err == nil { t.Fatalf("local forward should fail due to inactive link") } @@ -1738,24 +1738,47 @@ func TestSwitchSendPayment(t *testing.T) { PaymentHash: rhash, Amount: 1, } + paymentID := uint64(123) + + // First check that the switch will correctly respond that this payment + // ID is unknown. + _, err = s.GetPaymentResult( + paymentID, newMockDeobfuscator(), + ) + if err != ErrPaymentIDNotFound { + t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) + } // Handle the request and checks that bob channel link received it. errChan := make(chan error) go func() { - _, err := s.SendHTLC( - aliceChannelLink.ShortChanID(), update, - newMockDeobfuscator()) - errChan <- err - }() - - go func() { - // Send the payment with the same payment hash and same - // amount and check that it will be propagated successfully - _, err := s.SendHTLC( - aliceChannelLink.ShortChanID(), update, - newMockDeobfuscator(), + err := s.SendHTLC( + aliceChannelLink.ShortChanID(), paymentID, update, ) - errChan <- err + if err != nil { + errChan <- err + return + } + + resultChan, err := s.GetPaymentResult( + paymentID, newMockDeobfuscator(), + ) + if err != nil { + errChan <- err + return + } + + result, ok := <-resultChan + if !ok { + errChan <- fmt.Errorf("shutting down") + } + + if result.Error != nil { + errChan <- result.Error + return + } + + errChan <- nil }() select { @@ -1765,29 +1788,13 @@ func TestSwitchSendPayment(t *testing.T) { } case err := <-errChan: - if err != ErrPaymentInFlight { + if err != nil { t.Fatalf("unable to send payment: %v", err) } case <-time.After(time.Second): t.Fatal("request was not propagated to destination") } - select { - case packet := <-aliceChannelLink.packets: - if err := aliceChannelLink.completeCircuit(packet); err != nil { - t.Fatalf("unable to complete payment circuit: %v", err) - } - - case err := <-errChan: - t.Fatalf("unable to send payment: %v", err) - case <-time.After(time.Second): - t.Fatal("request was not propagated to destination") - } - - if s.numPendingPayments() != 1 { - t.Fatal("wrong amount of pending payments") - } - if s.circuits.NumOpen() != 1 { t.Fatal("wrong amount of circuits") } @@ -1824,10 +1831,6 @@ func TestSwitchSendPayment(t *testing.T) { case <-time.After(time.Second): t.Fatal("err wasn't received") } - - if s.numPendingPayments() != 0 { - t.Fatal("wrong amount of pending payments") - } } // TestLocalPaymentNoForwardingEvents tests that if we send a series of locally diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 2e1907f2..fa8ff886 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -543,7 +543,7 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) { func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, blob [lnwire.OnionPacketSize]byte, preimage, rhash [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, - error) { + uint64, error) { // Create the db invoice. Normally the payment requests needs to be set, // because it is decoded in InvoiceRegistry to obtain the cltv expiry. @@ -566,18 +566,25 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, OnionBlob: blob, } - return invoice, htlc, nil + pid, err := generateRandomBytes(8) + if err != nil { + return nil, nil, 0, err + } + paymentID := binary.BigEndian.Uint64(pid) + + return invoice, htlc, paymentID, nil } // generatePayment generates the htlc add request by given path blob and // invoice which should be added by destination peer. func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, - blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, error) { + blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice, + *lnwire.UpdateAddHTLC, uint64, error) { var preimage [sha256.Size]byte r, err := generateRandomBytes(sha256.Size) if err != nil { - return nil, nil, err + return nil, nil, 0, err } copy(preimage[:], r) @@ -772,7 +779,9 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, } // Generate payment: invoice and htlc. - invoice, htlc, err := generatePayment(invoiceAmt, htlcAmt, timelock, blob) + invoice, htlc, pid, err := generatePayment( + invoiceAmt, htlcAmt, timelock, blob, + ) if err != nil { return nil, nil, err } @@ -785,10 +794,29 @@ func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. return invoice, func() error { - _, err := sender.htlcSwitch.SendHTLC( - firstHop, htlc, newMockDeobfuscator(), + err := sender.htlcSwitch.SendHTLC( + firstHop, pid, htlc, ) - return err + if err != nil { + return err + } + resultChan, err := sender.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) + if err != nil { + return err + } + + result, ok := <-resultChan + if !ok { + return fmt.Errorf("shutting down") + } + + if result.Error != nil { + return result.Error + } + + return nil }, nil } @@ -1235,8 +1263,10 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, rhash := preimage.Hash() // Generate payment: invoice and htlc. - invoice, htlc, err := generatePaymentWithPreimage(invoiceAmt, htlcAmt, timelock, blob, - channeldb.UnknownPreimage, rhash) + invoice, htlc, pid, err := generatePaymentWithPreimage( + invoiceAmt, htlcAmt, timelock, blob, + channeldb.UnknownPreimage, rhash, + ) if err != nil { paymentErr <- err return paymentErr @@ -1250,10 +1280,32 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, // Send payment and expose err channel. go func() { - _, err := sender.htlcSwitch.SendHTLC( - firstHop, htlc, newMockDeobfuscator(), + err := sender.htlcSwitch.SendHTLC( + firstHop, pid, htlc, ) - paymentErr <- err + if err != nil { + paymentErr <- err + return + } + + resultChan, err := sender.htlcSwitch.GetPaymentResult( + pid, newMockDeobfuscator(), + ) + if err != nil { + paymentErr <- err + return + } + + result, ok := <-resultChan + if !ok { + paymentErr <- fmt.Errorf("shutting down") + } + + if result.Error != nil { + paymentErr <- result.Error + return + } + paymentErr <- nil }() return paymentErr diff --git a/routing/mock_test.go b/routing/mock_test.go new file mode 100644 index 00000000..03aa2923 --- /dev/null +++ b/routing/mock_test.go @@ -0,0 +1,64 @@ +package routing + +import ( + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnwire" +) + +type mockPaymentAttemptDispatcher struct { + onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error) + results map[uint64]*htlcswitch.PaymentResult +} + +var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil) + +func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID, + pid uint64, + _ *lnwire.UpdateAddHTLC) error { + + if m.onPayment == nil { + return nil + } + + if m.results == nil { + m.results = make(map[uint64]*htlcswitch.PaymentResult) + } + + var result *htlcswitch.PaymentResult + preimage, err := m.onPayment(firstHop) + if err != nil { + fwdErr, ok := err.(*htlcswitch.ForwardingError) + if !ok { + return err + } + result = &htlcswitch.PaymentResult{ + Error: fwdErr, + } + } else { + result = &htlcswitch.PaymentResult{Preimage: preimage} + } + + m.results[pid] = result + + return nil +} + +func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64, + _ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) { + + c := make(chan *htlcswitch.PaymentResult, 1) + res, ok := m.results[paymentID] + if !ok { + return nil, htlcswitch.ErrPaymentIDNotFound + } + c <- res + + return c, nil + +} + +func (m *mockPaymentAttemptDispatcher) setPaymentResult( + f func(firstHop lnwire.ShortChannelID) ([32]byte, error)) { + + m.onPayment = f +} diff --git a/routing/router.go b/routing/router.go index d49533c6..a03c3c9e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2,7 +2,6 @@ package routing import ( "bytes" - "crypto/sha256" "fmt" "runtime" "sync" @@ -126,6 +125,28 @@ type ChannelGraphSource interface { e1, e2 *channeldb.ChannelEdgePolicy) error) error } +// PaymentAttemptDispatcher is used by the router to send payment attempts onto +// the network, and receive their results. +type PaymentAttemptDispatcher interface { + // SendHTLC is a function that directs a link-layer switch to + // forward a fully encoded payment to the first hop in the route + // denoted by its public key. A non-nil error is to be returned if the + // payment was unsuccessful. + SendHTLC(firstHop lnwire.ShortChannelID, + paymentID uint64, + htlcAdd *lnwire.UpdateAddHTLC) error + + // GetPaymentResult returns the the result of the payment attempt with + // the given paymentID. The method returns a channel where the payment + // result will be sent when available, or an error is encountered + // during forwarding. When a result is received on the channel, the + // HTLC is guaranteed to no longer be in flight. The switch shutting + // down is signaled by closing the channel. If the paymentID is + // unknown, ErrPaymentIDNotFound will be returned. + GetPaymentResult(paymentID uint64, deobfuscator htlcswitch.ErrorDecrypter) ( + <-chan *htlcswitch.PaymentResult, error) +} + // FeeSchema is the set fee configuration for a Lightning Node on the network. // Using the coefficients described within the schema, the required fee to // forward outgoing payments can be derived. @@ -173,13 +194,10 @@ type Config struct { // we need in order to properly maintain the channel graph. ChainView chainview.FilteredChainView - // SendToSwitch is a function that directs a link-layer switch to - // forward a fully encoded payment to the first hop in the route - // denoted by its public key. A non-nil error is to be returned if the - // payment was unsuccessful. - SendToSwitch func(firstHop lnwire.ShortChannelID, - htlcAdd *lnwire.UpdateAddHTLC, - circuit *sphinx.Circuit) ([sha256.Size]byte, error) + // Payer is an instance of a PaymentAttemptDispatcher and is used by + // the router to send payment attempts onto the network, and receive + // their results. + Payer PaymentAttemptDispatcher // ChannelPruneExpiry is the duration used to determine if a channel // should be pruned or not. If the delta between now and when the @@ -199,6 +217,12 @@ type Config struct { // returned. QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + // NextPaymentID is a method that guarantees to return a new, unique ID + // each time it is called. This is used by the router to generate a + // unique payment ID for each payment it attempts to send, such that + // the switch can properly handle the HTLC. + NextPaymentID func() (uint64, error) + // AssumeChannelValid toggles whether or not the router will check for // spentness of channel outpoints. For neutrino, this saves long rescans // from blocking initial usage of the daemon. @@ -1381,12 +1405,22 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, return route, nil } +// generateNewSessionKey generates a new ephemeral private key to be used for a +// payment attempt. +func generateNewSessionKey() (*btcec.PrivateKey, error) { + // Generate a new random session key to ensure that we don't trigger + // any replay. + // + // TODO(roasbeef): add more sources of randomness? + return btcec.NewPrivateKey(btcec.S256()) +} + // generateSphinxPacket generates then encodes a sphinx packet which encodes // the onion route specified by the passed layer 3 route. The blob returned // from this function can immediately be included within an HTLC add packet to // be sent to the first hop within the route. -func generateSphinxPacket(rt *route.Route, paymentHash []byte) ([]byte, - *sphinx.Circuit, error) { +func generateSphinxPacket(rt *route.Route, paymentHash []byte, + sessionKey *btcec.PrivateKey) ([]byte, *sphinx.Circuit, error) { // As a sanity check, we'll ensure that the set of hops has been // properly filled in, otherwise, we won't actually be able to @@ -1410,15 +1444,6 @@ func generateSphinxPacket(rt *route.Route, paymentHash []byte) ([]byte, }), ) - // Generate a new random session key to ensure that we don't trigger - // any replay. - // - // TODO(roasbeef): add more sources of randomness? - sessionKey, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - return nil, nil, err - } - // Next generate the onion routing packet which allows us to perform // privacy preserving source routing across the network. sphinxPacket, err := sphinx.NewOnionPacket( @@ -1654,32 +1679,19 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, }), ) - preimage, err := r.sendToSwitch(route, paymentHash) - if err == nil { - return preimage, true, nil + // Generate a new key to be used for this attempt. + sessionKey, err := generateNewSessionKey() + if err != nil { + return [32]byte{}, true, err } - - log.Errorf("Attempt to send payment %x failed: %v", - paymentHash, err) - - finalOutcome := r.processSendError(paySession, route, err) - - return [32]byte{}, finalOutcome, err -} - -// sendToSwitch sends a payment along the specified route and returns the -// obtained preimage. -func (r *ChannelRouter) sendToSwitch(route *route.Route, paymentHash [32]byte) ( - [32]byte, error) { - // Generate the raw encoded sphinx packet to be included along // with the htlcAdd message that we send directly to the // switch. onionBlob, circuit, err := generateSphinxPacket( - route, paymentHash[:], + route, paymentHash[:], sessionKey, ) if err != nil { - return [32]byte{}, err + return [32]byte{}, true, err } // Craft an HTLC packet to send to the layer 2 switch. The @@ -1698,9 +1710,74 @@ func (r *ChannelRouter) sendToSwitch(route *route.Route, paymentHash [32]byte) ( firstHop := lnwire.NewShortChanIDFromInt( route.Hops[0].ChannelID, ) - return r.cfg.SendToSwitch( - firstHop, htlcAdd, circuit, + + // We generate a new, unique payment ID that we will use for + // this HTLC. + paymentID, err := r.cfg.NextPaymentID() + if err != nil { + return [32]byte{}, true, err + } + + err = r.cfg.Payer.SendHTLC( + firstHop, paymentID, htlcAdd, ) + if err != nil { + log.Errorf("Failed sending attempt %d for payment %x to "+ + "switch: %v", paymentID, paymentHash, err) + + // We must inspect the error to know whether it was critical or + // not, to decide whether we should continue trying. + finalOutcome := r.processSendError( + paySession, route, err, + ) + + return [32]byte{}, finalOutcome, err + } + + // Using the created circuit, initialize the error decrypter so we can + // parse+decode any failures incurred by this payment within the + // switch. + errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ + OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), + } + + // Now ask the switch to return the result of the payment when + // available. + resultChan, err := r.cfg.Payer.GetPaymentResult( + paymentID, errorDecryptor, + ) + if err != nil { + log.Errorf("Failed getting result for paymentID %d "+ + "from switch: %v", paymentID, err) + return [32]byte{}, true, err + } + + var ( + result *htlcswitch.PaymentResult + ok bool + ) + select { + case result, ok = <-resultChan: + if !ok { + return [32]byte{}, true, htlcswitch.ErrSwitchExiting + } + + case <-r.quit: + return [32]byte{}, true, ErrRouterShuttingDown + } + + if result.Error != nil { + log.Errorf("Attempt to send payment %x failed: %v", + paymentHash, result.Error) + + finalOutcome := r.processSendError( + paySession, route, result.Error, + ) + + return [32]byte{}, finalOutcome, result.Error + } + + return result.Preimage, true, nil } // processSendError analyzes the error for the payment attempt received from the diff --git a/routing/router_test.go b/routing/router_test.go index f4367d29..fb516000 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -6,6 +6,7 @@ import ( "image/color" "math/rand" "strings" + "sync/atomic" "testing" "time" @@ -15,7 +16,6 @@ import ( "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" @@ -24,6 +24,8 @@ import ( "github.com/lightningnetwork/lnd/zpay32" ) +var uniquePaymentID uint64 = 1 // to be used atomically + type testCtx struct { router *ChannelRouter @@ -44,13 +46,10 @@ func (c *testCtx) RestartRouter() error { // With the chainView reset, we'll now re-create the router itself, and // start it. router, err := New(Config{ - Graph: c.graph, - Chain: c.chain, - ChainView: c.chainView, - SendToSwitch: func(_ lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { - return [32]byte{}, nil - }, + Graph: c.graph, + Chain: c.chain, + ChainView: c.chainView, + Payer: &mockPaymentAttemptDispatcher{}, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, }) @@ -85,19 +84,19 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr chain := newMockChain(startingHeight) chainView := newMockChainView(chain) router, err := New(Config{ - Graph: graphInstance.graph, - Chain: chain, - ChainView: chainView, - SendToSwitch: func(_ lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { - - return [32]byte{}, nil - }, + Graph: graphInstance.graph, + Chain: chain, + ChainView: chainView, + Payer: &mockPaymentAttemptDispatcher{}, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { return lnwire.NewMSatFromSatoshis(e.Capacity) }, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, }) if err != nil { return nil, nil, fmt.Errorf("unable to create router %v", err) @@ -250,24 +249,24 @@ func TestSendPaymentRouteFailureFallback(t *testing.T) { // router's configuration to ignore the path that has luo ji as the // first hop. This should force the router to instead take the // available two hop path (through satoshi). - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - roasbeefLuoji := lnwire.NewShortChanIDFromInt(689530843) - if firstHop == roasbeefLuoji { - pub, err := sourceNode.PubKey() - if err != nil { - return preImage, err + roasbeefLuoji := lnwire.NewShortChanIDFromInt(689530843) + if firstHop == roasbeefLuoji { + pub, err := sourceNode.PubKey() + if err != nil { + return preImage, err + } + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: pub, + // TODO(roasbeef): temp node failure should be? + FailureMessage: &lnwire.FailTemporaryChannelFailure{}, + } } - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: pub, - // TODO(roasbeef): temp node failure should be? - FailureMessage: &lnwire.FailTemporaryChannelFailure{}, - } - } - return preImage, nil - } + return preImage, nil + }) // Send off the payment request to the router, route through satoshi // should've been selected as a fall back and succeeded correctly. @@ -387,24 +386,24 @@ func TestChannelUpdateValidation(t *testing.T) { // We'll modify the SendToSwitch method so that it simulates a failed // payment with an error originating from the first hop of the route. // The unsigned channel update is attached to the failure message. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - v := ctx.aliases["b"] - source, err := btcec.ParsePubKey( - v[:], btcec.S256(), - ) - if err != nil { - t.Fatal(err) - } + v := ctx.aliases["b"] + source, err := btcec.ParsePubKey( + v[:], btcec.S256(), + ) + if err != nil { + t.Fatal(err) + } - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: source, - FailureMessage: &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, - }, - } - } + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: source, + FailureMessage: &lnwire.FailFeeInsufficient{ + Update: errChanUpdate, + }, + } + }) // The payment parameter is mostly redundant in SendToRoute. Can be left // empty for this test. @@ -518,32 +517,32 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to Son goku. This will be a fee related error, so // it should only cause the edge to be pruned after the second attempt. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - roasbeefSongoku := lnwire.NewShortChanIDFromInt(chanID) - if firstHop == roasbeefSongoku { - sourceKey, err := btcec.ParsePubKey( - sourceNode[:], btcec.S256(), - ) - if err != nil { - t.Fatal(err) + roasbeefSongoku := lnwire.NewShortChanIDFromInt(chanID) + if firstHop == roasbeefSongoku { + sourceKey, err := btcec.ParsePubKey( + sourceNode[:], btcec.S256(), + ) + if err != nil { + t.Fatal(err) + } + + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: sourceKey, + + // Within our error, we'll add a channel update + // which is meant to reflect he new fee + // schedule for the node/channel. + FailureMessage: &lnwire.FailFeeInsufficient{ + Update: errChanUpdate, + }, + } } - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: sourceKey, - - // Within our error, we'll add a channel update - // which is meant to reflect he new fee - // schedule for the node/channel. - FailureMessage: &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, - }, - } - } - - return preImage, nil - } + return preImage, nil + }) // Send off the payment request to the router, route through satoshi // should've been selected as a fall back and succeeded correctly. @@ -633,27 +632,27 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { // outgoing channel to son goku. Since this is a time lock related // error, we should fail the payment flow all together, as Goku is the // only channel to Sophon. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - if firstHop == roasbeefSongoku { - sourceKey, err := btcec.ParsePubKey( - sourceNode[:], btcec.S256(), - ) - if err != nil { - t.Fatal(err) + if firstHop == roasbeefSongoku { + sourceKey, err := btcec.ParsePubKey( + sourceNode[:], btcec.S256(), + ) + if err != nil { + t.Fatal(err) + } + + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: sourceKey, + FailureMessage: &lnwire.FailExpiryTooSoon{ + Update: errChanUpdate, + }, + } } - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: sourceKey, - FailureMessage: &lnwire.FailExpiryTooSoon{ - Update: errChanUpdate, - }, - } - } - - return preImage, nil - } + return preImage, nil + }) // assertExpectedPath is a helper function that asserts the returned // route properly routes around the failure we've introduced in the @@ -694,27 +693,27 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { // We'll now modify the error return an IncorrectCltvExpiry error // instead, this should result in the same behavior of roasbeef routing // around the faulty Son Goku node. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - if firstHop == roasbeefSongoku { - sourceKey, err := btcec.ParsePubKey( - sourceNode[:], btcec.S256(), - ) - if err != nil { - t.Fatal(err) + if firstHop == roasbeefSongoku { + sourceKey, err := btcec.ParsePubKey( + sourceNode[:], btcec.S256(), + ) + if err != nil { + t.Fatal(err) + } + + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: sourceKey, + FailureMessage: &lnwire.FailIncorrectCltvExpiry{ + Update: errChanUpdate, + }, + } } - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: sourceKey, - FailureMessage: &lnwire.FailIncorrectCltvExpiry{ - Update: errChanUpdate, - }, - } - } - - return preImage, nil - } + return preImage, nil + }) // Once again, Roasbeef should route around Goku since they disagree // w.r.t to the block height, and instead go through Pham Nuwen. @@ -771,40 +770,40 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // // TODO(roasbeef): filtering should be intelligent enough so just not // go through satoshi at all at this point. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - if firstHop == roasbeefLuoji { - // We'll first simulate an error from the first - // outgoing link to simulate the channel from luo ji to - // roasbeef not having enough capacity. - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: sourcePub, - FailureMessage: &lnwire.FailTemporaryChannelFailure{}, - } - } - - // Next, we'll create an error from satoshi to indicate - // that the luoji node is not longer online, which should - // prune out the rest of the routes. - roasbeefSatoshi := lnwire.NewShortChanIDFromInt(2340213491) - if firstHop == roasbeefSatoshi { - vertex := ctx.aliases["satoshi"] - key, err := btcec.ParsePubKey( - vertex[:], btcec.S256(), - ) - if err != nil { - t.Fatal(err) + if firstHop == roasbeefLuoji { + // We'll first simulate an error from the first + // outgoing link to simulate the channel from luo ji to + // roasbeef not having enough capacity. + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: sourcePub, + FailureMessage: &lnwire.FailTemporaryChannelFailure{}, + } } - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: key, - FailureMessage: &lnwire.FailUnknownNextPeer{}, - } - } + // Next, we'll create an error from satoshi to indicate + // that the luoji node is not longer online, which should + // prune out the rest of the routes. + roasbeefSatoshi := lnwire.NewShortChanIDFromInt(2340213491) + if firstHop == roasbeefSatoshi { + vertex := ctx.aliases["satoshi"] + key, err := btcec.ParsePubKey( + vertex[:], btcec.S256(), + ) + if err != nil { + t.Fatal(err) + } - return preImage, nil - } + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: key, + FailureMessage: &lnwire.FailUnknownNextPeer{}, + } + } + + return preImage, nil + }) ctx.router.missionControl.ResetHistory() @@ -826,18 +825,18 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // Next, we'll modify the SendToSwitch method to indicate that luo ji // wasn't originally online. This should also halt the send all // together as all paths contain luoji and he can't be reached. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - if firstHop == roasbeefLuoji { - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: sourcePub, - FailureMessage: &lnwire.FailUnknownNextPeer{}, + if firstHop == roasbeefLuoji { + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: sourcePub, + FailureMessage: &lnwire.FailUnknownNextPeer{}, + } } - } - return preImage, nil - } + return preImage, nil + }) // This shouldn't return an error, as we'll make a payment attempt via // the satoshi channel based on the assumption that there might be an @@ -869,20 +868,20 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { // Finally, we'll modify the SendToSwitch function to indicate that the // roasbeef -> luoji channel has insufficient capacity. This should // again cause us to instead go via the satoshi route. - ctx.router.cfg.SendToSwitch = func(firstHop lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + ctx.router.cfg.Payer.(*mockPaymentAttemptDispatcher).setPaymentResult( + func(firstHop lnwire.ShortChannelID) ([32]byte, error) { - if firstHop == roasbeefLuoji { - // We'll first simulate an error from the first - // outgoing link to simulate the channel from luo ji to - // roasbeef not having enough capacity. - return [32]byte{}, &htlcswitch.ForwardingError{ - ErrorSource: sourcePub, - FailureMessage: &lnwire.FailTemporaryChannelFailure{}, + if firstHop == roasbeefLuoji { + // We'll first simulate an error from the first + // outgoing link to simulate the channel from luo ji to + // roasbeef not having enough capacity. + return [32]byte{}, &htlcswitch.ForwardingError{ + ErrorSource: sourcePub, + FailureMessage: &lnwire.FailTemporaryChannelFailure{}, + } } - } - return preImage, nil - } + return preImage, nil + }) paymentPreImage, rt, err = ctx.router.SendPayment(&payment) if err != nil { @@ -1525,13 +1524,10 @@ func TestWakeUpOnStaleBranch(t *testing.T) { // Create new router with same graph database. router, err := New(Config{ - Graph: ctx.graph, - Chain: ctx.chain, - ChainView: ctx.chainView, - SendToSwitch: func(_ lnwire.ShortChannelID, - _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { - return [32]byte{}, nil - }, + Graph: ctx.graph, + Chain: ctx.chain, + ChainView: ctx.chainView, + Payer: &mockPaymentAttemptDispatcher{}, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, }) @@ -2446,8 +2442,9 @@ func TestIsStaleEdgePolicy(t *testing.T) { func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { t.Parallel() + sessionKey, _ := btcec.NewPrivateKey(btcec.S256()) emptyRoute := &route.Route{} - _, _, err := generateSphinxPacket(emptyRoute, testHash[:]) + _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) if err != route.ErrNoRouteHopsProvided { t.Fatalf("expected empty hops error: instead got: %v", err) } diff --git a/server.go b/server.go index 72cbb1e3..cc2033ad 100644 --- a/server.go +++ b/server.go @@ -609,25 +609,18 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, } s.currentNodeAnn = nodeAnn + // The router will get access to the payment ID sequencer, such that it + // can generate unique payment IDs. + sequencer, err := htlcswitch.NewPersistentSequencer(chanDB) + if err != nil { + return nil, err + } + s.chanRouter, err = routing.New(routing.Config{ - Graph: chanGraph, - Chain: cc.chainIO, - ChainView: cc.chainView, - SendToSwitch: func(firstHop lnwire.ShortChannelID, - htlcAdd *lnwire.UpdateAddHTLC, - circuit *sphinx.Circuit) ([32]byte, error) { - - // Using the created circuit, initialize the error - // decrypter so we can parse+decode any failures - // incurred by this payment within the switch. - errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), - } - - return s.htlcSwitch.SendHTLC( - firstHop, htlcAdd, errorDecryptor, - ) - }, + Graph: chanGraph, + Chain: cc.chainIO, + ChainView: cc.chainView, + Payer: s.htlcSwitch, ChannelPruneExpiry: routing.DefaultChannelPruneExpiry, GraphPruneInterval: time.Duration(time.Hour), QueryBandwidth: func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { @@ -660,6 +653,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, return link.Bandwidth() }, AssumeChannelValid: cfg.Routing.UseAssumeChannelValid(), + NextPaymentID: sequencer.NextID, }) if err != nil { return nil, fmt.Errorf("can't create router: %v", err)