htlcswitch: Assign each pending payment a unique ID.
This simplifies the pending payment handling code because it allows it be handled in nearly the same way as forwarded HTLCs by treating an empty channel ID as local dispatch.
This commit is contained in:
parent
4a29fbdab2
commit
40fb0ddcfc
@ -721,19 +721,16 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
|
||||
"local_log_index=%v, batch_size=%v",
|
||||
htlc.PaymentHash[:], index, l.batchCounter+1)
|
||||
|
||||
// If packet was forwarded from another channel link then we should
|
||||
// create circuit (remember the path) in order to forward settle/fail
|
||||
// Create circuit (remember the path) in order to forward settle/fail
|
||||
// packet back.
|
||||
if pkt.incomingChanID != (lnwire.ShortChannelID{}) {
|
||||
l.cfg.Switch.addCircuit(&PaymentCircuit{
|
||||
PaymentHash: htlc.PaymentHash,
|
||||
IncomingChanID: pkt.incomingChanID,
|
||||
IncomingHTLCID: pkt.incomingHTLCID,
|
||||
OutgoingChanID: l.ShortChanID(),
|
||||
OutgoingHTLCID: index,
|
||||
ErrorEncrypter: pkt.obfuscator,
|
||||
})
|
||||
}
|
||||
l.cfg.Switch.addCircuit(&PaymentCircuit{
|
||||
PaymentHash: htlc.PaymentHash,
|
||||
IncomingChanID: pkt.incomingChanID,
|
||||
IncomingHTLCID: pkt.incomingHTLCID,
|
||||
OutgoingChanID: l.ShortChanID(),
|
||||
OutgoingHTLCID: index,
|
||||
ErrorEncrypter: pkt.obfuscator,
|
||||
})
|
||||
|
||||
htlc.ID = index
|
||||
l.cfg.Peer.SendMessage(htlc)
|
||||
|
@ -1448,7 +1448,7 @@ func newSingleLinkTestHarness(chanAmt btcutil.Amount) (ChannelLink, func(), erro
|
||||
aliceCfg := ChannelLinkConfig{
|
||||
FwrdingPolicy: globalPolicy,
|
||||
Peer: &alicePeer,
|
||||
Switch: nil,
|
||||
Switch: New(Config{}),
|
||||
DecodeHopIterator: decoder.DecodeHopIterator,
|
||||
DecodeOnionObfuscator: func(io.Reader) (ErrorEncrypter, lnwire.FailCode) {
|
||||
return obfuscator, lnwire.CodeNone
|
||||
|
@ -121,11 +121,13 @@ type Switch struct {
|
||||
// service was initialized with.
|
||||
cfg *Config
|
||||
|
||||
// pendingPayments is correspondence of user payments and its hashes,
|
||||
// which is used to save the payments which made by user and notify
|
||||
// them about result later.
|
||||
pendingPayments map[lnwallet.PaymentHash][]*pendingPayment
|
||||
// pendingPayments stores payments initiated by the user that are not yet
|
||||
// settled. The map is used to later look up the payments and notify the
|
||||
// user of the result when they are complete. Each payment is given a unique
|
||||
// integer ID when it is created.
|
||||
pendingPayments map[uint64]*pendingPayment
|
||||
pendingMutex sync.RWMutex
|
||||
nextPendingID uint64
|
||||
|
||||
// circuits is storage for payment circuits which are used to
|
||||
// forward the settle/fail htlc updates back to the add htlc initiator.
|
||||
@ -171,7 +173,7 @@ func New(cfg Config) *Switch {
|
||||
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
||||
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
|
||||
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
|
||||
pendingPayments: make(map[lnwallet.PaymentHash][]*pendingPayment),
|
||||
pendingPayments: make(map[uint64]*pendingPayment),
|
||||
htlcPlex: make(chan *plexPacket),
|
||||
chanCloseRequests: make(chan *ChanClose),
|
||||
linkControl: make(chan interface{}),
|
||||
@ -195,19 +197,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
|
||||
}
|
||||
|
||||
s.pendingMutex.Lock()
|
||||
s.pendingPayments[htlc.PaymentHash] = append(
|
||||
s.pendingPayments[htlc.PaymentHash], payment)
|
||||
paymentID := s.nextPendingID
|
||||
s.nextPendingID++
|
||||
s.pendingPayments[paymentID] = payment
|
||||
s.pendingMutex.Unlock()
|
||||
|
||||
// Generate and send new update packet, if error will be received on
|
||||
// this stage it means that packet haven't left boundaries of our
|
||||
// system and something wrong happened.
|
||||
packet := &htlcPacket{
|
||||
destNode: nextNode,
|
||||
htlc: htlc,
|
||||
incomingHTLCID: paymentID,
|
||||
destNode: nextNode,
|
||||
htlc: htlc,
|
||||
}
|
||||
if err := s.forward(packet); err != nil {
|
||||
s.removePendingPayment(payment.amount, payment.paymentHash)
|
||||
s.removePendingPayment(paymentID)
|
||||
return zeroPreimage, err
|
||||
}
|
||||
|
||||
@ -345,7 +349,16 @@ func (s *Switch) forward(packet *htlcPacket) error {
|
||||
// o <-settle-- o <--settle-- o
|
||||
// Alice Bob Carol
|
||||
//
|
||||
func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket) error {
|
||||
func (s *Switch) handleLocalDispatch(packet *htlcPacket) error {
|
||||
// Pending payments use a special interpretation of the incomingChanID and
|
||||
// incomingHTLCID fields on packet where the channel ID is blank and the
|
||||
// HTLC ID is the payment ID. The switch basically views the users of the
|
||||
// node as a special channel that also offers a sequence of HTLCs.
|
||||
payment, err := s.findPayment(packet.incomingHTLCID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch htlc := packet.htlc.(type) {
|
||||
|
||||
// User have created the htlc update therefore we should find the
|
||||
@ -407,6 +420,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
|
||||
// manages then channel.
|
||||
//
|
||||
// TODO(roasbeef): should return with an error
|
||||
packet.outgoingChanID = destination.ShortChanID()
|
||||
destination.HandleSwitchPacket(packet)
|
||||
return nil
|
||||
|
||||
@ -416,7 +430,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
|
||||
// Notify the user that his payment was successfully proceed.
|
||||
payment.err <- nil
|
||||
payment.preimage <- htlc.PaymentPreimage
|
||||
s.removePendingPayment(payment.amount, payment.paymentHash)
|
||||
s.removePendingPayment(packet.incomingHTLCID)
|
||||
|
||||
// We've just received a fail update which means we can finalize the
|
||||
// user payment and return fail response.
|
||||
@ -439,7 +453,7 @@ func (s *Switch) handleLocalDispatch(payment *pendingPayment, packet *htlcPacket
|
||||
}
|
||||
|
||||
payment.preimage <- zeroPreimage
|
||||
s.removePendingPayment(payment.amount, payment.paymentHash)
|
||||
s.removePendingPayment(packet.incomingHTLCID)
|
||||
|
||||
default:
|
||||
return errors.New("wrong update type")
|
||||
@ -458,6 +472,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||
// payment circuit within our internal state so we can properly forward
|
||||
// the ultimate settle message back latter.
|
||||
case *lnwire.UpdateAddHTLC:
|
||||
if packet.incomingChanID == (lnwire.ShortChannelID{}) {
|
||||
// A blank incomingChanID indicates that this is a pending
|
||||
// user-initiated payment.
|
||||
return s.handleLocalDispatch(packet)
|
||||
}
|
||||
|
||||
source, err := s.getLinkByShortID(packet.incomingChanID)
|
||||
if err != nil {
|
||||
err := errors.Errorf("unable to find channel link "+
|
||||
@ -581,15 +601,21 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
||||
circuit.OutgoingChanID)
|
||||
}
|
||||
|
||||
packet.incomingChanID = circuit.IncomingChanID
|
||||
packet.incomingHTLCID = circuit.IncomingHTLCID
|
||||
|
||||
// A blank IncomingChanID in a circuit indicates that it is a
|
||||
// pending user-initiated payment.
|
||||
if circuit.IncomingChanID == (lnwire.ShortChannelID{}) {
|
||||
return s.handleLocalDispatch(packet)
|
||||
}
|
||||
|
||||
// Obfuscate the error message for fail updates before sending back
|
||||
// through the circuit.
|
||||
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
|
||||
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
|
||||
htlc.Reason)
|
||||
}
|
||||
|
||||
packet.incomingChanID = circuit.IncomingChanID
|
||||
packet.incomingHTLCID = circuit.IncomingHTLCID
|
||||
}
|
||||
|
||||
source, err := s.getLinkByShortID(packet.incomingChanID)
|
||||
@ -696,37 +722,7 @@ func (s *Switch) htlcForwarder() {
|
||||
// packet concretely, then either forward it along, or
|
||||
// interpret a return packet to a locally initialized one.
|
||||
case cmd := <-s.htlcPlex:
|
||||
var (
|
||||
paymentHash lnwallet.PaymentHash
|
||||
amount lnwire.MilliSatoshi
|
||||
)
|
||||
|
||||
// Only three types of message should be forwarded:
|
||||
// add, fails, and settles. Anything else is an error.
|
||||
switch m := cmd.pkt.htlc.(type) {
|
||||
case *lnwire.UpdateAddHTLC:
|
||||
paymentHash = m.PaymentHash
|
||||
amount = m.Amount
|
||||
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
|
||||
paymentHash = cmd.pkt.payHash
|
||||
amount = cmd.pkt.amount
|
||||
default:
|
||||
cmd.err <- errors.New("wrong type of update")
|
||||
return
|
||||
}
|
||||
|
||||
// If we can locate this packet in our local records,
|
||||
// then this means a local sub-system initiated it.
|
||||
// Otherwise, this is just a packet to be forwarded, so
|
||||
// we'll treat it as so.
|
||||
//
|
||||
// TODO(roasbeef): can fast path this
|
||||
payment, err := s.findPayment(amount, paymentHash)
|
||||
if err != nil {
|
||||
cmd.err <- s.handlePacketForward(cmd.pkt)
|
||||
} else {
|
||||
cmd.err <- s.handleLocalDispatch(payment, cmd.pkt)
|
||||
}
|
||||
cmd.err <- s.handlePacketForward(cmd.pkt)
|
||||
|
||||
// The log ticker has fired, so we'll calculate some forwarding
|
||||
// stats for the last 10 seconds to display within the logs to
|
||||
@ -1034,64 +1030,36 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
|
||||
|
||||
// removePendingPayment is the helper function which removes the pending user
|
||||
// payment.
|
||||
func (s *Switch) removePendingPayment(amount lnwire.MilliSatoshi,
|
||||
hash lnwallet.PaymentHash) error {
|
||||
|
||||
func (s *Switch) removePendingPayment(paymentID uint64) error {
|
||||
s.pendingMutex.Lock()
|
||||
defer s.pendingMutex.Unlock()
|
||||
|
||||
payments, ok := s.pendingPayments[hash]
|
||||
if ok {
|
||||
for i, payment := range payments {
|
||||
if payment.amount == amount {
|
||||
// Delete without preserving order
|
||||
// Google: Golang slice tricks
|
||||
payments[i] = payments[len(payments)-1]
|
||||
payments[len(payments)-1] = nil
|
||||
s.pendingPayments[hash] = payments[:len(payments)-1]
|
||||
|
||||
if len(s.pendingPayments[hash]) == 0 {
|
||||
delete(s.pendingPayments, hash)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if _, ok := s.pendingPayments[paymentID]; !ok {
|
||||
return errors.Errorf("Cannot find pending payment with ID %d",
|
||||
paymentID)
|
||||
}
|
||||
|
||||
return errors.Errorf("unable to remove pending payment with "+
|
||||
"hash(%v) and amount(%v)", hash, amount)
|
||||
delete(s.pendingPayments, paymentID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// findPayment is the helper function which find the payment.
|
||||
func (s *Switch) findPayment(amount lnwire.MilliSatoshi,
|
||||
hash lnwallet.PaymentHash) (*pendingPayment, error) {
|
||||
|
||||
func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) {
|
||||
s.pendingMutex.RLock()
|
||||
defer s.pendingMutex.RUnlock()
|
||||
|
||||
payments, ok := s.pendingPayments[hash]
|
||||
if ok {
|
||||
for _, payment := range payments {
|
||||
if payment.amount == amount {
|
||||
return payment, nil
|
||||
}
|
||||
}
|
||||
payment, ok := s.pendingPayments[paymentID]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Cannot find pending payment with ID %d",
|
||||
paymentID)
|
||||
}
|
||||
|
||||
return nil, errors.Errorf("unable to remove pending payment with "+
|
||||
"hash(%v) and amount(%v)", hash, amount)
|
||||
return payment, nil
|
||||
}
|
||||
|
||||
// numPendingPayments is helper function which returns the overall number of
|
||||
// pending user payments.
|
||||
func (s *Switch) numPendingPayments() int {
|
||||
var l int
|
||||
for _, payments := range s.pendingPayments {
|
||||
l += len(payments)
|
||||
}
|
||||
|
||||
return l
|
||||
return len(s.pendingPayments)
|
||||
}
|
||||
|
||||
// addCircuit adds a circuit to the switch's in-memory mapping.
|
||||
|
Loading…
Reference in New Issue
Block a user