diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 82600b4e..7b9f7486 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -721,19 +721,16 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { "local_log_index=%v, batch_size=%v", htlc.PaymentHash[:], index, l.batchCounter+1) - // If packet was forwarded from another channel link then we should - // create circuit (remember the path) in order to forward settle/fail + // Create circuit (remember the path) in order to forward settle/fail // packet back. - if pkt.incomingChanID != (lnwire.ShortChannelID{}) { - l.cfg.Switch.addCircuit(&PaymentCircuit{ - PaymentHash: htlc.PaymentHash, - IncomingChanID: pkt.incomingChanID, - IncomingHTLCID: pkt.incomingHTLCID, - OutgoingChanID: l.ShortChanID(), - OutgoingHTLCID: index, - ErrorEncrypter: pkt.obfuscator, - }) - } + l.cfg.Switch.addCircuit(&PaymentCircuit{ + PaymentHash: htlc.PaymentHash, + IncomingChanID: pkt.incomingChanID, + IncomingHTLCID: pkt.incomingHTLCID, + OutgoingChanID: l.ShortChanID(), + OutgoingHTLCID: index, + ErrorEncrypter: pkt.obfuscator, + }) htlc.ID = index l.cfg.Peer.SendMessage(htlc) diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 636518d7..635b940c 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -1448,7 +1448,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro aliceCfg := ChannelLinkConfig{ FwrdingPolicy: globalPolicy, Peer: &alicePeer, - Switch: nil, + Switch: New(Config{}), DecodeHopIterator: decoder.DecodeHopIterator, DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) { return obfuscator, lnwire.CodeNone diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 1219b24b..6524af9a 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -121,11 +121,13 @@ type Switch struct { // service was initialized with. cfg *Config - // pendingPayments is correspondence of user payments and its hashes, - // which is used to save the payments which made by user and notify - // them about result later. - pendingPayments map[lnwallet.PaymentHash][]*pendingPayment + // pendingPayments stores payments initiated by the user that are not yet + // settled. The map is used to later look up the payments and notify the + // user of the result when they are complete. Each payment is given a unique + // integer ID when it is created. + pendingPayments map[uint64]*pendingPayment pendingMutex sync.RWMutex + nextPendingID uint64 // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. @@ -171,7 +173,7 @@ func New(cfg Config) *Switch { linkIndex: make(map[lnwire.ChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}), - pendingPayments: make(map[lnwallet.PaymentHash][]*pendingPayment), + pendingPayments: make(map[uint64]*pendingPayment), htlcPlex: make(chan *plexPacket), chanCloseRequests: make(chan *ChanClose), linkControl: make(chan interface{}), @@ -195,19 +197,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, } s.pendingMutex.Lock() - s.pendingPayments[htlc.PaymentHash] = append( - s.pendingPayments[htlc.PaymentHash], payment) + paymentID := s.nextPendingID + s.nextPendingID++ + s.pendingPayments[paymentID] = payment s.pendingMutex.Unlock() // Generate and send new update packet, if error will be received on // this stage it means that packet haven't left boundaries of our // system and something wrong happened. packet := &htlcPacket{ - destNode: nextNode, - htlc: htlc, + incomingHTLCID: paymentID, + destNode: nextNode, + htlc: htlc, } if err := s.forward(packet); err != nil { - s.removePendingPayment(payment.amount, payment.paymentHash) + s.removePendingPayment(paymentID) return zeroPreimage, err } @@ -345,7 +349,16 @@ func (s *Switch) forward(packet *htlcPacket) error { // o <-settle-- o <--settle-- o // Alice Bob Carol // -func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket) error { +func (s *Switch) handleLocalDispatch(packet *htlcPacket) error { + // Pending payments use a special interpretation of the incomingChanID and + // incomingHTLCID fields on packet where the channel ID is blank and the + // HTLC ID is the payment ID. The switch basically views the users of the + // node as a special channel that also offers a sequence of HTLCs. + payment, err := s.findPayment(packet.incomingHTLCID) + if err != nil { + return err + } + switch htlc := packet.htlc.(type) { // User have created the htlc update therefore we should find the @@ -407,6 +420,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket // manages then channel. // // TODO(roasbeef): should return with an error + packet.outgoingChanID = destination.ShortChanID() destination.HandleSwitchPacket(packet) return nil @@ -416,7 +430,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket // Notify the user that his payment was successfully proceed. payment.err <- nil payment.preimage <- htlc.PaymentPreimage - s.removePendingPayment(payment.amount, payment.paymentHash) + s.removePendingPayment(packet.incomingHTLCID) // We've just received a fail update which means we can finalize the // user payment and return fail response. @@ -439,7 +453,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket } payment.preimage <- zeroPreimage - s.removePendingPayment(payment.amount, payment.paymentHash) + s.removePendingPayment(packet.incomingHTLCID) default: return errors.New("wrong update type") @@ -458,6 +472,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // payment circuit within our internal state so we can properly forward // the ultimate settle message back latter. case *lnwire.UpdateAddHTLC: + if packet.incomingChanID == (lnwire.ShortChannelID{}) { + // A blank incomingChanID indicates that this is a pending + // user-initiated payment. + return s.handleLocalDispatch(packet) + } + source, err := s.getLinkByShortID(packet.incomingChanID) if err != nil { err := errors.Errorf("unable to find channel link "+ @@ -581,15 +601,21 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { circuit.OutgoingChanID) } + packet.incomingChanID = circuit.IncomingChanID + packet.incomingHTLCID = circuit.IncomingHTLCID + + // A blank IncomingChanID in a circuit indicates that it is a + // pending user-initiated payment. + if circuit.IncomingChanID == (lnwire.ShortChannelID{}) { + return s.handleLocalDispatch(packet) + } + // Obfuscate the error message for fail updates before sending back // through the circuit. if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated { htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( htlc.Reason) } - - packet.incomingChanID = circuit.IncomingChanID - packet.incomingHTLCID = circuit.IncomingHTLCID } source, err := s.getLinkByShortID(packet.incomingChanID) @@ -696,37 +722,7 @@ func (s *Switch) htlcForwarder() { // packet concretely, then either forward it along, or // interpret a return packet to a locally initialized one. case cmd := <-s.htlcPlex: - var ( - paymentHash lnwallet.PaymentHash - amount lnwire.MilliSatoshi - ) - - // Only three types of message should be forwarded: - // add, fails, and settles. Anything else is an error. - switch m := cmd.pkt.htlc.(type) { - case *lnwire.UpdateAddHTLC: - paymentHash = m.PaymentHash - amount = m.Amount - case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC: - paymentHash = cmd.pkt.payHash - amount = cmd.pkt.amount - default: - cmd.err <- errors.New("wrong type of update") - return - } - - // If we can locate this packet in our local records, - // then this means a local sub-system initiated it. - // Otherwise, this is just a packet to be forwarded, so - // we'll treat it as so. - // - // TODO(roasbeef): can fast path this - payment, err := s.findPayment(amount, paymentHash) - if err != nil { - cmd.err <- s.handlePacketForward(cmd.pkt) - } else { - cmd.err <- s.handleLocalDispatch(payment, cmd.pkt) - } + cmd.err <- s.handlePacketForward(cmd.pkt) // The log ticker has fired, so we'll calculate some forwarding // stats for the last 10 seconds to display within the logs to @@ -1034,64 +1030,36 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { // removePendingPayment is the helper function which removes the pending user // payment. -func (s *Switch) removePendingPayment(amount lnwire.MilliSatoshi, - hash lnwallet.PaymentHash) error { - +func (s *Switch) removePendingPayment(paymentID uint64) error { s.pendingMutex.Lock() defer s.pendingMutex.Unlock() - payments, ok := s.pendingPayments[hash] - if ok { - for i, payment := range payments { - if payment.amount == amount { - // Delete without preserving order - // Google: Golang slice tricks - payments[i] = payments[len(payments)-1] - payments[len(payments)-1] = nil - s.pendingPayments[hash] = payments[:len(payments)-1] - - if len(s.pendingPayments[hash]) == 0 { - delete(s.pendingPayments, hash) - } - - return nil - } - } + if _, ok := s.pendingPayments[paymentID]; !ok { + return errors.Errorf("Cannot find pending payment with ID %d", + paymentID) } - return errors.Errorf("unable to remove pending payment with "+ - "hash(%v) and amount(%v)", hash, amount) + delete(s.pendingPayments, paymentID) + return nil } // findPayment is the helper function which find the payment. -func (s *Switch) findPayment(amount lnwire.MilliSatoshi, - hash lnwallet.PaymentHash) (*pendingPayment, error) { - +func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) { s.pendingMutex.RLock() defer s.pendingMutex.RUnlock() - payments, ok := s.pendingPayments[hash] - if ok { - for _, payment := range payments { - if payment.amount == amount { - return payment, nil - } - } + payment, ok := s.pendingPayments[paymentID] + if !ok { + return nil, errors.Errorf("Cannot find pending payment with ID %d", + paymentID) } - - return nil, errors.Errorf("unable to remove pending payment with "+ - "hash(%v) and amount(%v)", hash, amount) + return payment, nil } // numPendingPayments is helper function which returns the overall number of // pending user payments. func (s *Switch) numPendingPayments() int { - var l int - for _, payments := range s.pendingPayments { - l += len(payments) - } - - return l + return len(s.pendingPayments) } // addCircuit adds a circuit to the switch's in-memory mapping.