htlcswitch: Change circuit map keys to (channel ID, HTLC ID).
This changes the circuit map internals and API to reference circuits by a primary key of (channel ID, HTLC ID) instead of paymnet hash. This is because each circuit has a unique offered HTLC, but there may be multiple circuits for a payment hash with different source or destination channels.
This commit is contained in:
parent
bc8d674958
commit
1328e61c00
@ -1,135 +1,174 @@
|
|||||||
package htlcswitch
|
package htlcswitch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"fmt"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
)
|
)
|
||||||
|
|
||||||
// circuitKey uniquely identifies an active circuit between two open channels.
|
// PaymentCircuit is used by the HTLC switch subsystem to determine the
|
||||||
// Currently, the payment hash is used to uniquely identify each circuit.
|
// backwards path for the settle/fail HTLC messages. A payment circuit
|
||||||
type circuitKey [sha256.Size]byte
|
// will be created once a channel link forwards the HTLC add request and
|
||||||
|
// removed when we receive a settle/fail HTLC message.
|
||||||
// String returns the string representation of the circuitKey.
|
type PaymentCircuit struct {
|
||||||
func (k *circuitKey) String() string {
|
|
||||||
return hex.EncodeToString(k[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// paymentCircuit is used by the htlc switch subsystem to determine the
|
|
||||||
// forwards/backwards path for the settle/fail HTLC messages. A payment circuit
|
|
||||||
// will be created once a channel link forwards the htlc add request and
|
|
||||||
// removed when we receive settle/fail htlc message.
|
|
||||||
type paymentCircuit struct {
|
|
||||||
// PaymentHash used as unique identifier of payment.
|
// PaymentHash used as unique identifier of payment.
|
||||||
PaymentHash circuitKey
|
PaymentHash [32]byte
|
||||||
|
|
||||||
// Src identifies the channel from which add htlc request is came from
|
// IncomingChanID identifies the channel from which add HTLC request came
|
||||||
// and to which settle/fail htlc request will be returned back. Once
|
// and to which settle/fail HTLC request will be returned back. Once
|
||||||
// the switch forwards the settle/fail message to the src the circuit
|
// the switch forwards the settle/fail message to the src the circuit
|
||||||
// is considered to be completed.
|
// is considered to be completed.
|
||||||
Src lnwire.ShortChannelID
|
IncomingChanID lnwire.ShortChannelID
|
||||||
|
|
||||||
// Dest identifies the channel to which we propagate the htlc add
|
// IncomingHTLCID is the ID in the update_add_htlc message we received from
|
||||||
// update and from which we are expecting to receive htlc settle/fail
|
// the incoming channel, which will be included in any settle/fail messages
|
||||||
|
// we send back.
|
||||||
|
IncomingHTLCID uint64
|
||||||
|
|
||||||
|
// OutgoingChanID identifies the channel to which we propagate the HTLC add
|
||||||
|
// update and from which we are expecting to receive HTLC settle/fail
|
||||||
// request back.
|
// request back.
|
||||||
Dest lnwire.ShortChannelID
|
OutgoingChanID lnwire.ShortChannelID
|
||||||
|
|
||||||
|
// OutgoingHTLCID is the ID in the update_add_htlc message we sent to the
|
||||||
|
// outgoing channel.
|
||||||
|
OutgoingHTLCID uint64
|
||||||
|
|
||||||
// ErrorEncrypter is used to re-encrypt the onion failure before
|
// ErrorEncrypter is used to re-encrypt the onion failure before
|
||||||
// sending it back to the originator of the payment.
|
// sending it back to the originator of the payment.
|
||||||
ErrorEncrypter ErrorEncrypter
|
ErrorEncrypter ErrorEncrypter
|
||||||
|
|
||||||
// RefCount is used to count the circuits with the same circuit key.
|
|
||||||
RefCount int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newPaymentCircuit creates new payment circuit instance.
|
// circuitKey is a channel ID, HTLC ID tuple used as an identifying key for a
|
||||||
func newPaymentCircuit(src, dest lnwire.ShortChannelID, key circuitKey,
|
// payment circuit. The circuit map is keyed with the idenitifer for the
|
||||||
e ErrorEncrypter) *paymentCircuit {
|
// outgoing HTLC
|
||||||
|
type circuitKey struct {
|
||||||
return &paymentCircuit{
|
chanID lnwire.ShortChannelID
|
||||||
Src: src,
|
htlcID uint64
|
||||||
Dest: dest,
|
|
||||||
PaymentHash: key,
|
|
||||||
RefCount: 1,
|
|
||||||
ErrorEncrypter: e,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// isEqual checks the equality of two payment circuits.
|
// String returns a string representation of the circuitKey.
|
||||||
func (a *paymentCircuit) isEqual(b *paymentCircuit) bool {
|
func (k *circuitKey) String() string {
|
||||||
return bytes.Equal(a.PaymentHash[:], b.PaymentHash[:]) &&
|
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.chanID, k.htlcID)
|
||||||
a.Src == b.Src &&
|
|
||||||
a.Dest == b.Dest
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// circuitMap is a data structure that implements thread safe storage of
|
// CircuitMap is a data structure that implements thread safe storage of
|
||||||
// circuits. Each circuit key (payment hash) may have several of circuits
|
// circuit routing information. The switch consults a circuit map to determine
|
||||||
// corresponding to it due to the possibility of repeated payment hashes.
|
// where to forward HTLC update messages. Each circuit is stored with it's
|
||||||
|
// outgoing HTLC as the primary key because, each offered HTLC has at most one
|
||||||
|
// received HTLC, but there may be multiple offered or received HTLCs with the
|
||||||
|
// same payment hash. Circuits are also indexed to provide fast lookups by
|
||||||
|
// payment hash.
|
||||||
//
|
//
|
||||||
// TODO(andrew.shvv) make it persistent
|
// TODO(andrew.shvv) make it persistent
|
||||||
type circuitMap struct {
|
type CircuitMap struct {
|
||||||
sync.RWMutex
|
mtx sync.RWMutex
|
||||||
circuits map[circuitKey]*paymentCircuit
|
circuits map[circuitKey]*PaymentCircuit
|
||||||
|
hashIndex map[[32]byte]map[PaymentCircuit]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// newCircuitMap creates a new instance of the circuitMap.
|
// NewCircuitMap creates a new instance of the CircuitMap.
|
||||||
func newCircuitMap() *circuitMap {
|
func NewCircuitMap() *CircuitMap {
|
||||||
return &circuitMap{
|
return &CircuitMap{
|
||||||
circuits: make(map[circuitKey]*paymentCircuit),
|
circuits: make(map[circuitKey]*PaymentCircuit),
|
||||||
|
hashIndex: make(map[[32]byte]map[PaymentCircuit]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// add adds a new active payment circuit to the circuitMap.
|
// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC
|
||||||
func (m *circuitMap) add(circuit *paymentCircuit) error {
|
// IDs. Returns nil if there is no such circuit.
|
||||||
m.Lock()
|
func (cm *CircuitMap) LookupByHTLC(chanID lnwire.ShortChannelID, htlcID uint64) *PaymentCircuit {
|
||||||
defer m.Unlock()
|
cm.mtx.RLock()
|
||||||
|
|
||||||
// Examine the circuit map to see if this circuit is already in use or
|
key := circuitKey{
|
||||||
// not. If so, then we'll simply increment the reference count.
|
chanID: chanID,
|
||||||
// Otherwise, we'll create a new circuit from scratch.
|
htlcID: htlcID,
|
||||||
//
|
}
|
||||||
// TODO(roasbeef): include dest+src+amt in key
|
circuit := cm.circuits[key]
|
||||||
if c, ok := m.circuits[circuit.PaymentHash]; ok {
|
|
||||||
c.RefCount++
|
cm.mtx.RUnlock()
|
||||||
return nil
|
return circuit
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupByPaymentHash looks up and returns any payment circuits with a given
|
||||||
|
// payment hash.
|
||||||
|
func (cm *CircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
|
||||||
|
cm.mtx.RLock()
|
||||||
|
|
||||||
|
var circuits []*PaymentCircuit
|
||||||
|
if circuitSet, ok := cm.hashIndex[hash]; ok {
|
||||||
|
circuits = make([]*PaymentCircuit, 0, len(circuitSet))
|
||||||
|
for circuit := range circuitSet {
|
||||||
|
circuits = append(circuits, &circuit)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.circuits[circuit.PaymentHash] = circuit
|
cm.mtx.RUnlock()
|
||||||
|
return circuits
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds a new active payment circuit to the CircuitMap.
|
||||||
|
func (cm *CircuitMap) Add(circuit *PaymentCircuit) error {
|
||||||
|
cm.mtx.Lock()
|
||||||
|
|
||||||
|
key := circuitKey{
|
||||||
|
chanID: circuit.OutgoingChanID,
|
||||||
|
htlcID: circuit.OutgoingHTLCID,
|
||||||
|
}
|
||||||
|
cm.circuits[key] = circuit
|
||||||
|
|
||||||
|
// Add circuit to the hash index.
|
||||||
|
if _, ok := cm.hashIndex[circuit.PaymentHash]; !ok {
|
||||||
|
cm.hashIndex[circuit.PaymentHash] = make(map[PaymentCircuit]struct{})
|
||||||
|
}
|
||||||
|
cm.hashIndex[circuit.PaymentHash][*circuit] = struct{}{}
|
||||||
|
|
||||||
|
cm.mtx.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove destroys the target circuit by removing it from the circuit map.
|
// Remove destroys the target circuit by removing it from the circuit map.
|
||||||
func (m *circuitMap) remove(key circuitKey) (*paymentCircuit, error) {
|
func (cm *CircuitMap) Remove(chanID lnwire.ShortChannelID, htlcID uint64) error {
|
||||||
m.Lock()
|
cm.mtx.Lock()
|
||||||
defer m.Unlock()
|
defer cm.mtx.Unlock()
|
||||||
|
|
||||||
if circuit, ok := m.circuits[key]; ok {
|
// Look up circuit so that pointer can be matched in the hash index.
|
||||||
if circuit.RefCount--; circuit.RefCount == 0 {
|
key := circuitKey{
|
||||||
delete(m.circuits, key)
|
chanID: chanID,
|
||||||
|
htlcID: htlcID,
|
||||||
|
}
|
||||||
|
circuit, found := cm.circuits[key]
|
||||||
|
if !found {
|
||||||
|
return errors.Errorf("Can't find circuit for HTLC %v", key)
|
||||||
|
}
|
||||||
|
delete(cm.circuits, key)
|
||||||
|
|
||||||
|
// Remove circuit from hash index.
|
||||||
|
circuitsWithHash, ok := cm.hashIndex[circuit.PaymentHash]
|
||||||
|
if !ok {
|
||||||
|
return errors.Errorf("Can't find circuit in hash index for HTLC %v",
|
||||||
|
key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return circuit, nil
|
if _, ok = circuitsWithHash[*circuit]; !ok {
|
||||||
|
return errors.Errorf("Can't find circuit in hash index for HTLC %v",
|
||||||
|
key)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.Errorf("can't find circuit"+
|
delete(circuitsWithHash, *circuit)
|
||||||
" for key %v", key)
|
if len(circuitsWithHash) == 0 {
|
||||||
|
delete(cm.hashIndex, circuit.PaymentHash)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// pending returns number of circuits which are waiting for to be completed
|
// pending returns number of circuits which are waiting for to be completed
|
||||||
// (settle/fail responses to be received).
|
// (settle/fail responses to be received).
|
||||||
func (m *circuitMap) pending() int {
|
func (cm *CircuitMap) pending() int {
|
||||||
m.RLock()
|
cm.mtx.RLock()
|
||||||
defer m.RUnlock()
|
count := len(cm.circuits)
|
||||||
|
cm.mtx.RUnlock()
|
||||||
var length int
|
return count
|
||||||
for _, circuits := range m.circuits {
|
|
||||||
length += circuits.RefCount
|
|
||||||
}
|
|
||||||
|
|
||||||
return length
|
|
||||||
}
|
}
|
||||||
|
@ -700,6 +700,8 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
|
|||||||
|
|
||||||
failPkt := &htlcPacket{
|
failPkt := &htlcPacket{
|
||||||
src: l.ShortChanID(),
|
src: l.ShortChanID(),
|
||||||
|
dest: pkt.src,
|
||||||
|
destID: pkt.srcID,
|
||||||
payHash: htlc.PaymentHash,
|
payHash: htlc.PaymentHash,
|
||||||
amount: htlc.Amount,
|
amount: htlc.Amount,
|
||||||
isObfuscated: isObfuscated,
|
isObfuscated: isObfuscated,
|
||||||
@ -720,6 +722,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
|
|||||||
"local_log_index=%v, batch_size=%v",
|
"local_log_index=%v, batch_size=%v",
|
||||||
htlc.PaymentHash[:], index, l.batchCounter+1)
|
htlc.PaymentHash[:], index, l.batchCounter+1)
|
||||||
|
|
||||||
|
// If packet was forwarded from another channel link then we should
|
||||||
|
// create circuit (remember the path) in order to forward settle/fail
|
||||||
|
// packet back.
|
||||||
|
if pkt.src != (lnwire.ShortChannelID{}) {
|
||||||
|
l.cfg.Switch.addCircuit(&PaymentCircuit{
|
||||||
|
PaymentHash: htlc.PaymentHash,
|
||||||
|
IncomingChanID: pkt.src,
|
||||||
|
IncomingHTLCID: pkt.srcID,
|
||||||
|
OutgoingChanID: pkt.dest,
|
||||||
|
OutgoingHTLCID: index,
|
||||||
|
ErrorEncrypter: pkt.obfuscator,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
htlc.ID = index
|
htlc.ID = index
|
||||||
l.cfg.Peer.SendMessage(htlc)
|
l.cfg.Peer.SendMessage(htlc)
|
||||||
|
|
||||||
@ -1180,6 +1196,7 @@ func (l *channelLink) processLockedInHtlcs(
|
|||||||
case lnwallet.Settle:
|
case lnwallet.Settle:
|
||||||
settlePacket := &htlcPacket{
|
settlePacket := &htlcPacket{
|
||||||
src: l.ShortChanID(),
|
src: l.ShortChanID(),
|
||||||
|
srcID: pd.ParentIndex,
|
||||||
payHash: pd.RHash,
|
payHash: pd.RHash,
|
||||||
amount: pd.Amount,
|
amount: pd.Amount,
|
||||||
htlc: &lnwire.UpdateFufillHTLC{
|
htlc: &lnwire.UpdateFufillHTLC{
|
||||||
@ -1202,6 +1219,7 @@ func (l *channelLink) processLockedInHtlcs(
|
|||||||
// continue to propagate it.
|
// continue to propagate it.
|
||||||
failPacket := &htlcPacket{
|
failPacket := &htlcPacket{
|
||||||
src: l.ShortChanID(),
|
src: l.ShortChanID(),
|
||||||
|
srcID: pd.HtlcIndex,
|
||||||
payHash: pd.RHash,
|
payHash: pd.RHash,
|
||||||
amount: pd.Amount,
|
amount: pd.Amount,
|
||||||
isObfuscated: false,
|
isObfuscated: false,
|
||||||
@ -1573,6 +1591,7 @@ func (l *channelLink) processLockedInHtlcs(
|
|||||||
|
|
||||||
updatePacket := &htlcPacket{
|
updatePacket := &htlcPacket{
|
||||||
src: l.ShortChanID(),
|
src: l.ShortChanID(),
|
||||||
|
srcID: pd.HtlcIndex,
|
||||||
dest: fwdInfo.NextHop,
|
dest: fwdInfo.NextHop,
|
||||||
htlc: addMsg,
|
htlc: addMsg,
|
||||||
obfuscator: obfuscator,
|
obfuscator: obfuscator,
|
||||||
|
@ -27,6 +27,14 @@ type htlcPacket struct {
|
|||||||
// of the target link.
|
// of the target link.
|
||||||
src lnwire.ShortChannelID
|
src lnwire.ShortChannelID
|
||||||
|
|
||||||
|
// destID is the ID of the HTLC in the destination channel. This will be set
|
||||||
|
// when forwarding a settle or fail update back to the original source.
|
||||||
|
destID uint64
|
||||||
|
|
||||||
|
// srcID is the ID of the HTLC in the source channel. This will be set when
|
||||||
|
// forwarding any HTLC update message.
|
||||||
|
srcID uint64
|
||||||
|
|
||||||
// amount is the value of the HTLC that is being created or modified.
|
// amount is the value of the HTLC that is being created or modified.
|
||||||
amount lnwire.MilliSatoshi
|
amount lnwire.MilliSatoshi
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ type Switch struct {
|
|||||||
|
|
||||||
// circuits is storage for payment circuits which are used to
|
// circuits is storage for payment circuits which are used to
|
||||||
// forward the settle/fail htlc updates back to the add htlc initiator.
|
// forward the settle/fail htlc updates back to the add htlc initiator.
|
||||||
circuits *circuitMap
|
circuits *CircuitMap
|
||||||
|
|
||||||
// links is a map of channel id and channel link which manages
|
// links is a map of channel id and channel link which manages
|
||||||
// this channel.
|
// this channel.
|
||||||
@ -167,7 +167,7 @@ type Switch struct {
|
|||||||
func New(cfg Config) *Switch {
|
func New(cfg Config) *Switch {
|
||||||
return &Switch{
|
return &Switch{
|
||||||
cfg: &cfg,
|
cfg: &cfg,
|
||||||
circuits: newCircuitMap(),
|
circuits: NewCircuitMap(),
|
||||||
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
|
||||||
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
|
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
|
||||||
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
|
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
|
||||||
@ -481,7 +481,8 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
source.HandleSwitchPacket(&htlcPacket{
|
source.HandleSwitchPacket(&htlcPacket{
|
||||||
src: packet.src,
|
dest: packet.src,
|
||||||
|
destID: packet.srcID,
|
||||||
payHash: htlc.PaymentHash,
|
payHash: htlc.PaymentHash,
|
||||||
isObfuscated: true,
|
isObfuscated: true,
|
||||||
htlc: &lnwire.UpdateFailHTLC{
|
htlc: &lnwire.UpdateFailHTLC{
|
||||||
@ -529,7 +530,8 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
source.HandleSwitchPacket(&htlcPacket{
|
source.HandleSwitchPacket(&htlcPacket{
|
||||||
src: packet.src,
|
dest: packet.src,
|
||||||
|
destID: packet.srcID,
|
||||||
payHash: htlc.PaymentHash,
|
payHash: htlc.PaymentHash,
|
||||||
isObfuscated: true,
|
isObfuscated: true,
|
||||||
htlc: &lnwire.UpdateFailHTLC{
|
htlc: &lnwire.UpdateFailHTLC{
|
||||||
@ -544,38 +546,6 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If packet was forwarded from another channel link than we
|
|
||||||
// should create circuit (remember the path) in order to
|
|
||||||
// forward settle/fail packet back.
|
|
||||||
if err := s.circuits.add(newPaymentCircuit(
|
|
||||||
source.ShortChanID(),
|
|
||||||
destination.ShortChanID(),
|
|
||||||
htlc.PaymentHash,
|
|
||||||
packet.obfuscator,
|
|
||||||
)); err != nil {
|
|
||||||
failure := lnwire.NewTemporaryChannelFailure(nil)
|
|
||||||
reason, err := packet.obfuscator.EncryptFirstHop(failure)
|
|
||||||
if err != nil {
|
|
||||||
err := errors.Errorf("unable to obfuscate "+
|
|
||||||
"error: %v", err)
|
|
||||||
log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
source.HandleSwitchPacket(&htlcPacket{
|
|
||||||
src: packet.src,
|
|
||||||
payHash: htlc.PaymentHash,
|
|
||||||
isObfuscated: true,
|
|
||||||
htlc: &lnwire.UpdateFailHTLC{
|
|
||||||
Reason: reason,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
err = errors.Errorf("unable to add circuit: "+
|
|
||||||
"%v", err)
|
|
||||||
log.Error(err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the packet to the destination channel link which
|
// Send the packet to the destination channel link which
|
||||||
// manages the channel.
|
// manages the channel.
|
||||||
destination.HandleSwitchPacket(packet)
|
destination.HandleSwitchPacket(packet)
|
||||||
@ -585,37 +555,49 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
|
|||||||
// payment circuit by forwarding the settle msg to the channel from
|
// payment circuit by forwarding the settle msg to the channel from
|
||||||
// which htlc add packet was initially received.
|
// which htlc add packet was initially received.
|
||||||
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
|
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
|
||||||
// Exit if we can't find and remove the active circuit to
|
if packet.dest == (lnwire.ShortChannelID{}) {
|
||||||
// continue propagating the fail over.
|
// Use circuit map to find the link to forward settle/fail to.
|
||||||
circuit, err := s.circuits.remove(packet.payHash)
|
circuit := s.circuits.LookupByHTLC(packet.src, packet.srcID)
|
||||||
if err != nil {
|
if circuit == nil {
|
||||||
err := errors.Errorf("unable to remove "+
|
err := errors.Errorf("Unable to find source channel for HTLC "+
|
||||||
"circuit for payment hash: %v", packet.payHash)
|
"settle/fail: channel ID = %s, HTLC ID = %d, "+
|
||||||
|
"payment hash = %x", packet.src, packet.srcID,
|
||||||
|
packet.payHash[:])
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is failure than we need to obfuscate the error.
|
// Remove circuit since we are about to complete the HTLC.
|
||||||
|
err := s.circuits.Remove(packet.src, packet.srcID)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to close completed onion circuit for %x: "+
|
||||||
|
"%s<->%s", packet.payHash[:], circuit.IncomingChanID,
|
||||||
|
circuit.OutgoingChanID)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Closed completed onion circuit for %x: %s<->%s",
|
||||||
|
packet.payHash[:], circuit.IncomingChanID,
|
||||||
|
circuit.OutgoingChanID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obfuscate the error message for fail updates before sending back
|
||||||
|
// through the circuit.
|
||||||
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
|
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
|
||||||
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
|
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
|
||||||
htlc.Reason,
|
htlc.Reason)
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Propagating settle/fail htlc back to src of add htlc packet.
|
packet.dest = circuit.IncomingChanID
|
||||||
source, err := s.getLinkByShortID(circuit.Src)
|
packet.destID = circuit.IncomingHTLCID
|
||||||
|
}
|
||||||
|
|
||||||
|
source, err := s.getLinkByShortID(packet.dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := errors.Errorf("unable to get source "+
|
err := errors.Errorf("Unable to get source channel link to "+
|
||||||
"channel link to forward settle/fail htlc: %v",
|
"forward HTLC settle/fail: %v", err)
|
||||||
err)
|
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Closing completed onion "+
|
|
||||||
"circuit for %x: %v<->%v", packet.payHash[:],
|
|
||||||
circuit.Src, circuit.Dest)
|
|
||||||
|
|
||||||
source.HandleSwitchPacket(packet)
|
source.HandleSwitchPacket(packet)
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
@ -1109,3 +1091,8 @@ func (s *Switch) numPendingPayments() int {
|
|||||||
|
|
||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addCircuit adds a circuit to the switch's in-memory mapping.
|
||||||
|
func (s *Switch) addCircuit(circuit *PaymentCircuit) {
|
||||||
|
s.circuits.Add(circuit)
|
||||||
|
}
|
||||||
|
@ -57,6 +57,7 @@ func TestSwitchForward(t *testing.T) {
|
|||||||
rhash := fastsha256.Sum256(preimage[:])
|
rhash := fastsha256.Sum256(preimage[:])
|
||||||
packet := &htlcPacket{
|
packet := &htlcPacket{
|
||||||
src: aliceChannelLink.ShortChanID(),
|
src: aliceChannelLink.ShortChanID(),
|
||||||
|
srcID: 0,
|
||||||
dest: bobChannelLink.ShortChanID(),
|
dest: bobChannelLink.ShortChanID(),
|
||||||
obfuscator: newMockObfuscator(),
|
obfuscator: newMockObfuscator(),
|
||||||
htlc: &lnwire.UpdateAddHTLC{
|
htlc: &lnwire.UpdateAddHTLC{
|
||||||
@ -70,6 +71,15 @@ func TestSwitchForward(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.addCircuit(&PaymentCircuit{
|
||||||
|
PaymentHash: packet.payHash,
|
||||||
|
IncomingChanID: packet.src,
|
||||||
|
IncomingHTLCID: 0,
|
||||||
|
OutgoingChanID: packet.dest,
|
||||||
|
OutgoingHTLCID: 0,
|
||||||
|
ErrorEncrypter: packet.obfuscator,
|
||||||
|
})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-bobChannelLink.packets:
|
case <-bobChannelLink.packets:
|
||||||
break
|
break
|
||||||
@ -146,6 +156,7 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
|
|||||||
rhash := fastsha256.Sum256(preimage[:])
|
rhash := fastsha256.Sum256(preimage[:])
|
||||||
packet = &htlcPacket{
|
packet = &htlcPacket{
|
||||||
src: aliceChannelLink.ShortChanID(),
|
src: aliceChannelLink.ShortChanID(),
|
||||||
|
srcID: 0,
|
||||||
dest: bobChannelLink.ShortChanID(),
|
dest: bobChannelLink.ShortChanID(),
|
||||||
htlc: &lnwire.UpdateAddHTLC{
|
htlc: &lnwire.UpdateAddHTLC{
|
||||||
PaymentHash: rhash,
|
PaymentHash: rhash,
|
||||||
@ -234,6 +245,7 @@ func TestSwitchCancel(t *testing.T) {
|
|||||||
rhash := fastsha256.Sum256(preimage[:])
|
rhash := fastsha256.Sum256(preimage[:])
|
||||||
request := &htlcPacket{
|
request := &htlcPacket{
|
||||||
src: aliceChannelLink.ShortChanID(),
|
src: aliceChannelLink.ShortChanID(),
|
||||||
|
srcID: 0,
|
||||||
dest: bobChannelLink.ShortChanID(),
|
dest: bobChannelLink.ShortChanID(),
|
||||||
obfuscator: newMockObfuscator(),
|
obfuscator: newMockObfuscator(),
|
||||||
htlc: &lnwire.UpdateAddHTLC{
|
htlc: &lnwire.UpdateAddHTLC{
|
||||||
@ -247,6 +259,15 @@ func TestSwitchCancel(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.addCircuit(&PaymentCircuit{
|
||||||
|
PaymentHash: request.payHash,
|
||||||
|
IncomingChanID: request.src,
|
||||||
|
IncomingHTLCID: 0,
|
||||||
|
OutgoingChanID: request.dest,
|
||||||
|
OutgoingHTLCID: 0,
|
||||||
|
ErrorEncrypter: request.obfuscator,
|
||||||
|
})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-bobChannelLink.packets:
|
case <-bobChannelLink.packets:
|
||||||
break
|
break
|
||||||
@ -263,6 +284,7 @@ func TestSwitchCancel(t *testing.T) {
|
|||||||
// request should be forwarder back to alice channel link.
|
// request should be forwarder back to alice channel link.
|
||||||
request = &htlcPacket{
|
request = &htlcPacket{
|
||||||
src: bobChannelLink.ShortChanID(),
|
src: bobChannelLink.ShortChanID(),
|
||||||
|
srcID: 0,
|
||||||
payHash: rhash,
|
payHash: rhash,
|
||||||
amount: 1,
|
amount: 1,
|
||||||
isObfuscated: true,
|
isObfuscated: true,
|
||||||
@ -316,6 +338,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
|
|||||||
rhash := fastsha256.Sum256(preimage[:])
|
rhash := fastsha256.Sum256(preimage[:])
|
||||||
request := &htlcPacket{
|
request := &htlcPacket{
|
||||||
src: aliceChannelLink.ShortChanID(),
|
src: aliceChannelLink.ShortChanID(),
|
||||||
|
srcID: 0,
|
||||||
dest: bobChannelLink.ShortChanID(),
|
dest: bobChannelLink.ShortChanID(),
|
||||||
obfuscator: newMockObfuscator(),
|
obfuscator: newMockObfuscator(),
|
||||||
htlc: &lnwire.UpdateAddHTLC{
|
htlc: &lnwire.UpdateAddHTLC{
|
||||||
@ -329,6 +352,15 @@ func TestSwitchAddSamePayment(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.addCircuit(&PaymentCircuit{
|
||||||
|
PaymentHash: request.payHash,
|
||||||
|
IncomingChanID: request.src,
|
||||||
|
IncomingHTLCID: 0,
|
||||||
|
OutgoingChanID: request.dest,
|
||||||
|
OutgoingHTLCID: 0,
|
||||||
|
ErrorEncrypter: request.obfuscator,
|
||||||
|
})
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-bobChannelLink.packets:
|
case <-bobChannelLink.packets:
|
||||||
break
|
break
|
||||||
@ -340,11 +372,31 @@ func TestSwitchAddSamePayment(t *testing.T) {
|
|||||||
t.Fatal("wrong amount of circuits")
|
t.Fatal("wrong amount of circuits")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request = &htlcPacket{
|
||||||
|
src: aliceChannelLink.ShortChanID(),
|
||||||
|
srcID: 1,
|
||||||
|
dest: bobChannelLink.ShortChanID(),
|
||||||
|
obfuscator: newMockObfuscator(),
|
||||||
|
htlc: &lnwire.UpdateAddHTLC{
|
||||||
|
PaymentHash: rhash,
|
||||||
|
Amount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// Handle the request and checks that bob channel link received it.
|
// Handle the request and checks that bob channel link received it.
|
||||||
if err := s.forward(request); err != nil {
|
if err := s.forward(request); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.addCircuit(&PaymentCircuit{
|
||||||
|
PaymentHash: request.payHash,
|
||||||
|
IncomingChanID: request.src,
|
||||||
|
IncomingHTLCID: 1,
|
||||||
|
OutgoingChanID: request.dest,
|
||||||
|
OutgoingHTLCID: 1,
|
||||||
|
ErrorEncrypter: request.obfuscator,
|
||||||
|
})
|
||||||
|
|
||||||
if s.circuits.pending() != 2 {
|
if s.circuits.pending() != 2 {
|
||||||
t.Fatal("wrong amount of circuits")
|
t.Fatal("wrong amount of circuits")
|
||||||
}
|
}
|
||||||
@ -376,6 +428,15 @@ func TestSwitchAddSamePayment(t *testing.T) {
|
|||||||
t.Fatal("wrong amount of circuits")
|
t.Fatal("wrong amount of circuits")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
request = &htlcPacket{
|
||||||
|
src: bobChannelLink.ShortChanID(),
|
||||||
|
srcID: 1,
|
||||||
|
payHash: rhash,
|
||||||
|
amount: 1,
|
||||||
|
isObfuscated: true,
|
||||||
|
htlc: &lnwire.UpdateFailHTLC{},
|
||||||
|
}
|
||||||
|
|
||||||
// Handle the request and checks that payment circuit works properly.
|
// Handle the request and checks that payment circuit works properly.
|
||||||
if err := s.forward(request); err != nil {
|
if err := s.forward(request); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
package lnwire
|
package lnwire
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// ShortChannelID represents the set of data which is needed to retrieve all
|
// ShortChannelID represents the set of data which is needed to retrieve all
|
||||||
// necessary data to validate the channel existence.
|
// necessary data to validate the channel existence.
|
||||||
type ShortChannelID struct {
|
type ShortChannelID struct {
|
||||||
@ -37,3 +41,8 @@ func (c *ShortChannelID) ToUint64() uint64 {
|
|||||||
return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) |
|
return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) |
|
||||||
(uint64(c.TxPosition)))
|
(uint64(c.TxPosition)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String generates a human-readable representation of the channel ID.
|
||||||
|
func (c ShortChannelID) String() string {
|
||||||
|
return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user