htlcswitch/switch: use paymentResultStore to keep track of results

This commit is contained in:
Johan T. Halseth 2019-06-07 16:42:26 +02:00
parent 2dea790b55
commit 2cc778d309
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
2 changed files with 60 additions and 97 deletions

@ -3909,8 +3909,8 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc, n.firstBobChannelLink.ShortChanID(), pid, htlc,
) )
if err != ErrPaymentIDAlreadyExists { if err != ErrDuplicateAdd {
t.Fatalf("ErrPaymentIDAlreadyExists should have been "+ t.Fatalf("ErrDuplicateAdd should have been "+
"received got: %v", err) "received got: %v", err)
} }

@ -64,16 +64,6 @@ var (
zeroPreimage [sha256.Size]byte zeroPreimage [sha256.Size]byte
) )
// pendingPayment represents the payment which made by user and waits for
// updates to be received whether the payment has been rejected or proceed
// successfully.
type pendingPayment struct {
paymentHash lntypes.Hash
amount lnwire.MilliSatoshi
resultChan chan *networkResult
}
// plexPacket encapsulates switch packet and adds error channel to receive // plexPacket encapsulates switch packet and adds error channel to receive
// error from request handler. // error from request handler.
type plexPacket struct { type plexPacket struct {
@ -201,12 +191,12 @@ type Switch struct {
// service was initialized with. // service was initialized with.
cfg *Config cfg *Config
// pendingPayments stores payments initiated by the user that are not yet // networkResults stores the results of payments initiated by the user.
// settled. The map is used to later look up the payments and notify the // results. The store is used to later look up the payments and notify
// user of the result when they are complete. Each payment is given a unique // the user of the result when they are complete. Each payment attempt
// integer ID when it is created. // should be given a unique integer ID when it is created, otherwise
pendingPayments map[uint64]*pendingPayment // results might be overwritten.
pendingMutex sync.RWMutex networkResults *networkResultStore
// circuits is storage for payment circuits which are used to // circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator. // forward the settle/fail htlc updates back to the add htlc initiator.
@ -292,7 +282,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
interfaceIndex: make(map[[33]byte]map[lnwire.ChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[lnwire.ChannelID]ChannelLink),
pendingLinkIndex: make(map[lnwire.ChannelID]ChannelLink), pendingLinkIndex: make(map[lnwire.ChannelID]ChannelLink),
pendingPayments: make(map[uint64]*pendingPayment), networkResults: newNetworkResultStore(cfg.DB),
htlcPlex: make(chan *plexPacket), htlcPlex: make(chan *plexPacket),
chanCloseRequests: make(chan *ChanClose), chanCloseRequests: make(chan *ChanClose),
resolutionMsgs: make(chan *resolutionMsg), resolutionMsgs: make(chan *resolutionMsg),
@ -345,12 +335,33 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash, func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash,
deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) { deobfuscator ErrorDecrypter) (<-chan *PaymentResult, error) {
s.pendingMutex.Lock() var (
payment, ok := s.pendingPayments[paymentID] nChan <-chan *networkResult
s.pendingMutex.Unlock() err error
outKey = CircuitKey{
ChanID: sourceHop,
HtlcID: paymentID,
}
)
if !ok { // If the payment is not found in the circuit map, check whether a
return nil, ErrPaymentIDNotFound // result is already available.
// Assumption: no one will add this payment ID other than the caller.
if s.circuits.LookupCircuit(outKey) == nil {
res, err := s.networkResults.getResult(paymentID)
if err != nil {
return nil, err
}
c := make(chan *networkResult, 1)
c <- res
nChan = c
} else {
// The payment was committed to the circuits, subscribe for a
// result.
nChan, err = s.networkResults.subscribeResult(paymentID)
if err != nil {
return nil, err
}
} }
resultChan := make(chan *PaymentResult, 1) resultChan := make(chan *PaymentResult, 1)
@ -364,7 +375,7 @@ func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash,
var n *networkResult var n *networkResult
select { select {
case n = <-payment.resultChan: case n = <-nChan:
case <-s.quit: case <-s.quit:
// We close the result channel to signal a shutdown. We // We close the result channel to signal a shutdown. We
// don't send any result in this case since the HTLC is // don't send any result in this case since the HTLC is
@ -398,24 +409,6 @@ func (s *Switch) GetPaymentResult(paymentID uint64, paymentHash lntypes.Hash,
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
htlc *lnwire.UpdateAddHTLC) error { htlc *lnwire.UpdateAddHTLC) error {
// Create payment and add to the map of payment in order later to be
// able to retrieve it and return response to the user.
payment := &pendingPayment{
resultChan: make(chan *networkResult, 1),
paymentHash: htlc.PaymentHash,
amount: htlc.Amount,
}
s.pendingMutex.Lock()
if _, ok := s.pendingPayments[paymentID]; ok {
s.pendingMutex.Unlock()
return ErrPaymentIDAlreadyExists
}
s.pendingPayments[paymentID] = payment
s.pendingMutex.Unlock()
// Generate and send new update packet, if error will be received on // Generate and send new update packet, if error will be received on
// this stage it means that packet haven't left boundaries of our // this stage it means that packet haven't left boundaries of our
// system and something wrong happened. // system and something wrong happened.
@ -426,12 +419,7 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
htlc: htlc, htlc: htlc,
} }
if err := s.forward(packet); err != nil { return s.forward(packet)
s.removePendingPayment(paymentID)
return err
}
return nil
} }
// UpdateForwardingPolicies sends a message to the switch to update the // UpdateForwardingPolicies sends a message to the switch to update the
@ -856,15 +844,34 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
// multiple db transactions. The guarantees of the circuit map are stringent // multiple db transactions. The guarantees of the circuit map are stringent
// enough such that we are able to tolerate reordering of these operations // enough such that we are able to tolerate reordering of these operations
// without side effects. The primary operations handled are: // without side effects. The primary operations handled are:
// 1. Ack settle/fail references, to avoid resending this response internally // 1. Save the payment result to the pending payment store.
// 2. Teardown the closing circuit in the circuit map // 2. Notify subscribers about the payment result.
// 3. Transition the payment status to grounded or completed. // 3. Ack settle/fail references, to avoid resending this response internally
// 4. Respond to an in-mem pending payment, if it is found. // 4. Teardown the closing circuit in the circuit map
// //
// NOTE: This method MUST be spawned as a goroutine. // NOTE: This method MUST be spawned as a goroutine.
func (s *Switch) handleLocalResponse(pkt *htlcPacket) { func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
defer s.wg.Done() defer s.wg.Done()
paymentID := pkt.incomingHTLCID
// The error reason will be unencypted in case this a local
// failure or a converted error.
unencrypted := pkt.localFailure || pkt.convertedError
n := &networkResult{
msg: pkt.htlc,
unencrypted: unencrypted,
isResolution: pkt.isResolution,
}
// Store the result to the db. This will also notify subscribers about
// the result.
if err := s.networkResults.storeResult(paymentID, n); err != nil {
log.Errorf("Unable to complete payment for pid=%v: %v",
paymentID, err)
return
}
// First, we'll clean up any fwdpkg references, circuit entries, and // First, we'll clean up any fwdpkg references, circuit entries, and
// mark in our db that the payment for this payment hash has either // mark in our db that the payment for this payment hash has either
// succeeded or failed. // succeeded or failed.
@ -892,26 +899,6 @@ func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
pkt.inKey(), err) pkt.inKey(), err)
return return
} }
// Locate the pending payment to notify the application that this
// payment has failed. If one is not found, it likely means the daemon
// has been restarted since sending the payment.
payment := s.findPayment(pkt.incomingHTLCID)
// The error reason will be unencypted in case this a local
// failure or a converted error.
unencrypted := pkt.localFailure || pkt.convertedError
n := &networkResult{
msg: pkt.htlc,
unencrypted: unencrypted,
isResolution: pkt.isResolution,
}
// Deliver the payment error and preimage to the application, if it is
// waiting for a response.
if payment != nil {
payment.resultChan <- n
}
} }
// extractResult uses the given deobfuscator to extract the payment result from // extractResult uses the given deobfuscator to extract the payment result from
@ -2173,30 +2160,6 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
return channelLinks, nil return channelLinks, nil
} }
// removePendingPayment is the helper function which removes the pending user
// payment.
func (s *Switch) removePendingPayment(paymentID uint64) {
s.pendingMutex.Lock()
defer s.pendingMutex.Unlock()
delete(s.pendingPayments, paymentID)
}
// findPayment is the helper function which find the payment.
func (s *Switch) findPayment(paymentID uint64) *pendingPayment {
s.pendingMutex.RLock()
defer s.pendingMutex.RUnlock()
payment, ok := s.pendingPayments[paymentID]
if !ok {
log.Errorf("Cannot find pending payment with ID %d",
paymentID)
return nil
}
return payment
}
// CircuitModifier returns a reference to subset of the interfaces provided by // CircuitModifier returns a reference to subset of the interfaces provided by
// the circuit map, to allow links to open and close circuits. // the circuit map, to allow links to open and close circuits.
func (s *Switch) CircuitModifier() CircuitModifier { func (s *Switch) CircuitModifier() CircuitModifier {