diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index c0cb310b..4df57cec 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -805,10 +805,30 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { return link.HandleSwitchPacket(pkt) } - // Otherwise this is a response to a payment that we initiated. We'll - // clean up any fwdpkg references, circuit entries, and mark in our db - // that the payment for this payment hash has either succeeded or - // failed. + s.wg.Add(1) + go s.handleLocalResponse(pkt) + + return nil +} + +// handleLocalResponse processes a Settle or Fail responding to a +// locally-initiated payment. This is handled asynchronously to avoid blocking +// the main event loop within the switch, as these operations can require +// multiple db transactions. The guarantees of the circuit map are stringent +// enough such that we are able to tolerate reordering of these operations +// without side effects. The primary operations handled are: +// 1. Ack settle/fail references, to avoid resending this response internally +// 2. Teardown the closing circuit in the circuit map +// 3. Transition the payment status to grounded or completed. +// 4. Respond to an in-mem pending payment, if it is found. +// +// NOTE: This method MUST be spawned as a goroutine. +func (s *Switch) handleLocalResponse(pkt *htlcPacket) { + defer s.wg.Done() + + // First, we'll clean up any fwdpkg references, circuit entries, and + // mark in our db that the payment for this payment hash has either + // succeeded or failed. // // If this response is contained in a forwarding package, we'll start by // acking the settle/fail so that we don't continue to retransmit the @@ -817,7 +837,7 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { if err := s.ackSettleFail(*pkt.destRef); err != nil { log.Warnf("Unable to ack settle/fail reference: %s: %v", *pkt.destRef, err) - return err + return } } @@ -831,7 +851,7 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { if err := s.teardownCircuit(pkt); err != nil { log.Warnf("Unable to teardown circuit %s: %v", pkt.inKey(), err) - return err + return } // Locate the pending payment to notify the application that this @@ -854,7 +874,9 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { // payment to this hash. err := s.control.Success(pkt.circuit.PaymentHash) if err != nil && err != ErrPaymentAlreadyCompleted { - return err + log.Warnf("Unable to mark completed payment %x: %v", + pkt.circuit.PaymentHash, err) + return } preimage = htlc.PaymentPreimage @@ -867,13 +889,16 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { // payment. err := s.control.Fail(pkt.circuit.PaymentHash) if err != nil && err != ErrPaymentAlreadyCompleted { - return err + log.Warnf("Unable to ground payment %x: %v", + pkt.circuit.PaymentHash, err) + return } paymentErr = s.parseFailedPayment(payment, pkt, htlc) default: - return errors.New("wrong update type") + log.Warnf("Received unknown response type: %T", pkt.htlc) + return } // Deliver the payment error and preimage to the application, if it is @@ -883,8 +908,6 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { payment.preimage <- preimage s.removePendingPayment(pkt.incomingHTLCID) } - - return nil } // parseFailedPayment determines the appropriate failure message to return to @@ -2078,17 +2101,11 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { // removePendingPayment is the helper function which removes the pending user // payment. -func (s *Switch) removePendingPayment(paymentID uint64) error { +func (s *Switch) removePendingPayment(paymentID uint64) { s.pendingMutex.Lock() defer s.pendingMutex.Unlock() - if _, ok := s.pendingPayments[paymentID]; !ok { - return fmt.Errorf("Cannot find pending payment with ID %d", - paymentID) - } - delete(s.pendingPayments, paymentID) - return nil } // findPayment is the helper function which find the payment. @@ -2115,6 +2132,9 @@ func (s *Switch) CircuitModifier() CircuitModifier { // numPendingPayments is helper function which returns the overall number of // pending user payments. func (s *Switch) numPendingPayments() int { + s.pendingMutex.RLock() + defer s.pendingMutex.RUnlock() + return len(s.pendingPayments) }