htlcswitch: Assign each pending payment a unique ID.

This simplifies the pending payment handling code because it allows it
be handled in nearly the same way as forwarded HTLCs by treating an
empty channel ID as local dispatch.
This commit is contained in:
Jim Posen 2017-10-30 12:57:32 -07:00 committed by Olaoluwa Osuntokun
parent 4a29fbdab2
commit 40fb0ddcfc
3 changed files with 66 additions and 101 deletions

@ -721,19 +721,16 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
"local_log_index=%v, batch_size=%v",
htlc.PaymentHash[:], index, l.batchCounter+1)
// If packet was forwarded from another channel link then we should
// create circuit (remember the path) in order to forward settle/fail
// Create circuit (remember the path) in order to forward settle/fail
// packet back.
if pkt.incomingChanID != (lnwire.ShortChannelID{}) {
l.cfg.Switch.addCircuit(&PaymentCircuit{
PaymentHash: htlc.PaymentHash,
IncomingChanID: pkt.incomingChanID,
IncomingHTLCID: pkt.incomingHTLCID,
OutgoingChanID: l.ShortChanID(),
OutgoingHTLCID: index,
ErrorEncrypter: pkt.obfuscator,
})
}
l.cfg.Switch.addCircuit(&PaymentCircuit{
PaymentHash: htlc.PaymentHash,
IncomingChanID: pkt.incomingChanID,
IncomingHTLCID: pkt.incomingHTLCID,
OutgoingChanID: l.ShortChanID(),
OutgoingHTLCID: index,
ErrorEncrypter: pkt.obfuscator,
})
htlc.ID = index
l.cfg.Peer.SendMessage(htlc)

@ -1448,7 +1448,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro
aliceCfg := ChannelLinkConfig{
FwrdingPolicy: globalPolicy,
Peer: &alicePeer,
Switch: nil,
Switch: New(Config{}),
DecodeHopIterator: decoder.DecodeHopIterator,
DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone

@ -121,11 +121,13 @@ type Switch struct {
// service was initialized with.
cfg *Config
// pendingPayments is correspondence of user payments and its hashes,
// which is used to save the payments which made by user and notify
// them about result later.
pendingPayments map[lnwallet.PaymentHash][]*pendingPayment
// pendingPayments stores payments initiated by the user that are not yet
// settled. The map is used to later look up the payments and notify the
// user of the result when they are complete. Each payment is given a unique
// integer ID when it is created.
pendingPayments map[uint64]*pendingPayment
pendingMutex sync.RWMutex
nextPendingID uint64
// circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator.
@ -171,7 +173,7 @@ func New(cfg Config) *Switch {
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
pendingPayments: make(map[lnwallet.PaymentHash][]*pendingPayment),
pendingPayments: make(map[uint64]*pendingPayment),
htlcPlex: make(chan *plexPacket),
chanCloseRequests: make(chan *ChanClose),
linkControl: make(chan interface{}),
@ -195,19 +197,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
}
s.pendingMutex.Lock()
s.pendingPayments[htlc.PaymentHash] = append(
s.pendingPayments[htlc.PaymentHash], payment)
paymentID := s.nextPendingID
s.nextPendingID++
s.pendingPayments[paymentID] = payment
s.pendingMutex.Unlock()
// Generate and send new update packet, if error will be received on
// this stage it means that packet haven't left boundaries of our
// system and something wrong happened.
packet := &htlcPacket{
destNode: nextNode,
htlc: htlc,
incomingHTLCID: paymentID,
destNode: nextNode,
htlc: htlc,
}
if err := s.forward(packet); err != nil {
s.removePendingPayment(payment.amount, payment.paymentHash)
s.removePendingPayment(paymentID)
return zeroPreimage, err
}
@ -345,7 +349,16 @@ func (s *Switch) forward(packet *htlcPacket) error {
// o <-settle-- o <--settle-- o
// Alice Bob Carol
//
func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket) error {
func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
// Pending payments use a special interpretation of the incomingChanID and
// incomingHTLCID fields on packet where the channel ID is blank and the
// HTLC ID is the payment ID. The switch basically views the users of the
// node as a special channel that also offers a sequence of HTLCs.
payment, err := s.findPayment(packet.incomingHTLCID)
if err != nil {
return err
}
switch htlc := packet.htlc.(type) {
// User have created the htlc update therefore we should find the
@ -407,6 +420,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
// manages then channel.
//
// TODO(roasbeef): should return with an error
packet.outgoingChanID = destination.ShortChanID()
destination.HandleSwitchPacket(packet)
return nil
@ -416,7 +430,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
// Notify the user that his payment was successfully proceed.
payment.err <- nil
payment.preimage <- htlc.PaymentPreimage
s.removePendingPayment(payment.amount, payment.paymentHash)
s.removePendingPayment(packet.incomingHTLCID)
// We've just received a fail update which means we can finalize the
// user payment and return fail response.
@ -439,7 +453,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
}
payment.preimage <- zeroPreimage
s.removePendingPayment(payment.amount, payment.paymentHash)
s.removePendingPayment(packet.incomingHTLCID)
default:
return errors.New("wrong update type")
@ -458,6 +472,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// payment circuit within our internal state so we can properly forward
// the ultimate settle message back latter.
case *lnwire.UpdateAddHTLC:
if packet.incomingChanID == (lnwire.ShortChannelID{}) {
// A blank incomingChanID indicates that this is a pending
// user-initiated payment.
return s.handleLocalDispatch(packet)
}
source, err := s.getLinkByShortID(packet.incomingChanID)
if err != nil {
err := errors.Errorf("unable to find channel link "+
@ -581,15 +601,21 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
circuit.OutgoingChanID)
}
packet.incomingChanID = circuit.IncomingChanID
packet.incomingHTLCID = circuit.IncomingHTLCID
// A blank IncomingChanID in a circuit indicates that it is a
// pending user-initiated payment.
if circuit.IncomingChanID == (lnwire.ShortChannelID{}) {
return s.handleLocalDispatch(packet)
}
// Obfuscate the error message for fail updates before sending back
// through the circuit.
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
htlc.Reason)
}
packet.incomingChanID = circuit.IncomingChanID
packet.incomingHTLCID = circuit.IncomingHTLCID
}
source, err := s.getLinkByShortID(packet.incomingChanID)
@ -696,37 +722,7 @@ func (s *Switch) htlcForwarder() {
// packet concretely, then either forward it along, or
// interpret a return packet to a locally initialized one.
case cmd := <-s.htlcPlex:
var (
paymentHash lnwallet.PaymentHash
amount lnwire.MilliSatoshi
)
// Only three types of message should be forwarded:
// add, fails, and settles. Anything else is an error.
switch m := cmd.pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
paymentHash = m.PaymentHash
amount = m.Amount
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
paymentHash = cmd.pkt.payHash
amount = cmd.pkt.amount
default:
cmd.err <- errors.New("wrong type of update")
return
}
// If we can locate this packet in our local records,
// then this means a local sub-system initiated it.
// Otherwise, this is just a packet to be forwarded, so
// we'll treat it as so.
//
// TODO(roasbeef): can fast path this
payment, err := s.findPayment(amount, paymentHash)
if err != nil {
cmd.err <- s.handlePacketForward(cmd.pkt)
} else {
cmd.err <- s.handleLocalDispatch(payment, cmd.pkt)
}
cmd.err <- s.handlePacketForward(cmd.pkt)
// The log ticker has fired, so we'll calculate some forwarding
// stats for the last 10 seconds to display within the logs to
@ -1034,64 +1030,36 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
// removePendingPayment is the helper function which removes the pending user
// payment.
func (s *Switch) removePendingPayment(amount lnwire.MilliSatoshi,
hash lnwallet.PaymentHash) error {
func (s *Switch) removePendingPayment(paymentID uint64) error {
s.pendingMutex.Lock()
defer s.pendingMutex.Unlock()
payments, ok := s.pendingPayments[hash]
if ok {
for i, payment := range payments {
if payment.amount == amount {
// Delete without preserving order
// Google: Golang slice tricks
payments[i] = payments[len(payments)-1]
payments[len(payments)-1] = nil
s.pendingPayments[hash] = payments[:len(payments)-1]
if len(s.pendingPayments[hash]) == 0 {
delete(s.pendingPayments, hash)
}
return nil
}
}
if _, ok := s.pendingPayments[paymentID]; !ok {
return errors.Errorf("Cannot find pending payment with ID %d",
paymentID)
}
return errors.Errorf("unable to remove pending payment with "+
"hash(%v) and amount(%v)", hash, amount)
delete(s.pendingPayments, paymentID)
return nil
}
// findPayment is the helper function which find the payment.
func (s *Switch) findPayment(amount lnwire.MilliSatoshi,
hash lnwallet.PaymentHash) (*pendingPayment, error) {
func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) {
s.pendingMutex.RLock()
defer s.pendingMutex.RUnlock()
payments, ok := s.pendingPayments[hash]
if ok {
for _, payment := range payments {
if payment.amount == amount {
return payment, nil
}
}
payment, ok := s.pendingPayments[paymentID]
if !ok {
return nil, errors.Errorf("Cannot find pending payment with ID %d",
paymentID)
}
return nil, errors.Errorf("unable to remove pending payment with "+
"hash(%v) and amount(%v)", hash, amount)
return payment, nil
}
// numPendingPayments is helper function which returns the overall number of
// pending user payments.
func (s *Switch) numPendingPayments() int {
var l int
for _, payments := range s.pendingPayments {
l += len(payments)
}
return l
return len(s.pendingPayments)
}
// addCircuit adds a circuit to the switch's in-memory mapping.