htlcswitch: Change circuit map keys to (channel ID, HTLC ID).

This changes the circuit map internals and API to reference circuits
by a primary key of (channel ID, HTLC ID) instead of paymnet
hash. This is because each circuit has a unique offered HTLC, but
there may be multiple circuits for a payment hash with different
source or destination channels.
This commit is contained in:
Jim Posen 2017-10-23 15:50:26 -07:00 committed by Olaoluwa Osuntokun
parent bc8d674958
commit 1328e61c00
6 changed files with 274 additions and 151 deletions

@ -1,135 +1,174 @@
package htlcswitch
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"sync"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire"
)
// circuitKey uniquely identifies an active circuit between two open channels.
// Currently, the payment hash is used to uniquely identify each circuit.
type circuitKey [sha256.Size]byte
// String returns the string representation of the circuitKey.
func (k *circuitKey) String() string {
return hex.EncodeToString(k[:])
}
// paymentCircuit is used by the htlc switch subsystem to determine the
// forwards/backwards path for the settle/fail HTLC messages. A payment circuit
// will be created once a channel link forwards the htlc add request and
// removed when we receive settle/fail htlc message.
type paymentCircuit struct {
// PaymentCircuit is used by the HTLC switch subsystem to determine the
// backwards path for the settle/fail HTLC messages. A payment circuit
// will be created once a channel link forwards the HTLC add request and
// removed when we receive a settle/fail HTLC message.
type PaymentCircuit struct {
// PaymentHash used as unique identifier of payment.
PaymentHash circuitKey
PaymentHash [32]byte
// Src identifies the channel from which add htlc request is came from
// and to which settle/fail htlc request will be returned back. Once
// IncomingChanID identifies the channel from which add HTLC request came
// and to which settle/fail HTLC request will be returned back. Once
// the switch forwards the settle/fail message to the src the circuit
// is considered to be completed.
Src lnwire.ShortChannelID
IncomingChanID lnwire.ShortChannelID
// Dest identifies the channel to which we propagate the htlc add
// update and from which we are expecting to receive htlc settle/fail
// IncomingHTLCID is the ID in the update_add_htlc message we received from
// the incoming channel, which will be included in any settle/fail messages
// we send back.
IncomingHTLCID uint64
// OutgoingChanID identifies the channel to which we propagate the HTLC add
// update and from which we are expecting to receive HTLC settle/fail
// request back.
Dest lnwire.ShortChannelID
OutgoingChanID lnwire.ShortChannelID
// OutgoingHTLCID is the ID in the update_add_htlc message we sent to the
// outgoing channel.
OutgoingHTLCID uint64
// ErrorEncrypter is used to re-encrypt the onion failure before
// sending it back to the originator of the payment.
ErrorEncrypter ErrorEncrypter
// RefCount is used to count the circuits with the same circuit key.
RefCount int
}
// newPaymentCircuit creates new payment circuit instance.
func newPaymentCircuit(src, dest lnwire.ShortChannelID, key circuitKey,
e ErrorEncrypter) *paymentCircuit {
return &paymentCircuit{
Src: src,
Dest: dest,
PaymentHash: key,
RefCount: 1,
ErrorEncrypter: e,
}
// circuitKey is a channel ID, HTLC ID tuple used as an identifying key for a
// payment circuit. The circuit map is keyed with the idenitifer for the
// outgoing HTLC
type circuitKey struct {
chanID lnwire.ShortChannelID
htlcID uint64
}
// isEqual checks the equality of two payment circuits.
func (a *paymentCircuit) isEqual(b *paymentCircuit) bool {
return bytes.Equal(a.PaymentHash[:], b.PaymentHash[:]) &&
a.Src == b.Src &&
a.Dest == b.Dest
// String returns a string representation of the circuitKey.
func (k *circuitKey) String() string {
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.chanID, k.htlcID)
}
// circuitMap is a data structure that implements thread safe storage of
// circuits. Each circuit key (payment hash) may have several of circuits
// corresponding to it due to the possibility of repeated payment hashes.
// CircuitMap is a data structure that implements thread safe storage of
// circuit routing information. The switch consults a circuit map to determine
// where to forward HTLC update messages. Each circuit is stored with it's
// outgoing HTLC as the primary key because, each offered HTLC has at most one
// received HTLC, but there may be multiple offered or received HTLCs with the
// same payment hash. Circuits are also indexed to provide fast lookups by
// payment hash.
//
// TODO(andrew.shvv) make it persistent
type circuitMap struct {
sync.RWMutex
circuits map[circuitKey]*paymentCircuit
type CircuitMap struct {
mtx sync.RWMutex
circuits map[circuitKey]*PaymentCircuit
hashIndex map[[32]byte]map[PaymentCircuit]struct{}
}
// newCircuitMap creates a new instance of the circuitMap.
func newCircuitMap() *circuitMap {
return &circuitMap{
circuits: make(map[circuitKey]*paymentCircuit),
// NewCircuitMap creates a new instance of the CircuitMap.
func NewCircuitMap() *CircuitMap {
return &CircuitMap{
circuits: make(map[circuitKey]*PaymentCircuit),
hashIndex: make(map[[32]byte]map[PaymentCircuit]struct{}),
}
}
// add adds a new active payment circuit to the circuitMap.
func (m *circuitMap) add(circuit *paymentCircuit) error {
m.Lock()
defer m.Unlock()
// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC
// IDs. Returns nil if there is no such circuit.
func (cm *CircuitMap) LookupByHTLC(chanID lnwire.ShortChannelID, htlcID uint64) *PaymentCircuit {
cm.mtx.RLock()
// Examine the circuit map to see if this circuit is already in use or
// not. If so, then we'll simply increment the reference count.
// Otherwise, we'll create a new circuit from scratch.
//
// TODO(roasbeef): include dest+src+amt in key
if c, ok := m.circuits[circuit.PaymentHash]; ok {
c.RefCount++
return nil
key := circuitKey{
chanID: chanID,
htlcID: htlcID,
}
circuit := cm.circuits[key]
cm.mtx.RUnlock()
return circuit
}
// LookupByPaymentHash looks up and returns any payment circuits with a given
// payment hash.
func (cm *CircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit {
cm.mtx.RLock()
var circuits []*PaymentCircuit
if circuitSet, ok := cm.hashIndex[hash]; ok {
circuits = make([]*PaymentCircuit, 0, len(circuitSet))
for circuit := range circuitSet {
circuits = append(circuits, &circuit)
}
}
m.circuits[circuit.PaymentHash] = circuit
cm.mtx.RUnlock()
return circuits
}
// Add adds a new active payment circuit to the CircuitMap.
func (cm *CircuitMap) Add(circuit *PaymentCircuit) error {
cm.mtx.Lock()
key := circuitKey{
chanID: circuit.OutgoingChanID,
htlcID: circuit.OutgoingHTLCID,
}
cm.circuits[key] = circuit
// Add circuit to the hash index.
if _, ok := cm.hashIndex[circuit.PaymentHash]; !ok {
cm.hashIndex[circuit.PaymentHash] = make(map[PaymentCircuit]struct{})
}
cm.hashIndex[circuit.PaymentHash][*circuit] = struct{}{}
cm.mtx.Unlock()
return nil
}
// remove destroys the target circuit by removing it from the circuit map.
func (m *circuitMap) remove(key circuitKey) (*paymentCircuit, error) {
m.Lock()
defer m.Unlock()
// Remove destroys the target circuit by removing it from the circuit map.
func (cm *CircuitMap) Remove(chanID lnwire.ShortChannelID, htlcID uint64) error {
cm.mtx.Lock()
defer cm.mtx.Unlock()
if circuit, ok := m.circuits[key]; ok {
if circuit.RefCount--; circuit.RefCount == 0 {
delete(m.circuits, key)
// Look up circuit so that pointer can be matched in the hash index.
key := circuitKey{
chanID: chanID,
htlcID: htlcID,
}
circuit, found := cm.circuits[key]
if !found {
return errors.Errorf("Can't find circuit for HTLC %v", key)
}
delete(cm.circuits, key)
// Remove circuit from hash index.
circuitsWithHash, ok := cm.hashIndex[circuit.PaymentHash]
if !ok {
return errors.Errorf("Can't find circuit in hash index for HTLC %v",
key)
}
return circuit, nil
if _, ok = circuitsWithHash[*circuit]; !ok {
return errors.Errorf("Can't find circuit in hash index for HTLC %v",
key)
}
return nil, errors.Errorf("can't find circuit"+
" for key %v", key)
delete(circuitsWithHash, *circuit)
if len(circuitsWithHash) == 0 {
delete(cm.hashIndex, circuit.PaymentHash)
}
return nil
}
// pending returns number of circuits which are waiting for to be completed
// (settle/fail responses to be received).
func (m *circuitMap) pending() int {
m.RLock()
defer m.RUnlock()
var length int
for _, circuits := range m.circuits {
length += circuits.RefCount
}
return length
func (cm *CircuitMap) pending() int {
cm.mtx.RLock()
count := len(cm.circuits)
cm.mtx.RUnlock()
return count
}

@ -700,6 +700,8 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
failPkt := &htlcPacket{
src: l.ShortChanID(),
dest: pkt.src,
destID: pkt.srcID,
payHash: htlc.PaymentHash,
amount: htlc.Amount,
isObfuscated: isObfuscated,
@ -720,6 +722,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) {
"local_log_index=%v, batch_size=%v",
htlc.PaymentHash[:], index, l.batchCounter+1)
// If packet was forwarded from another channel link then we should
// create circuit (remember the path) in order to forward settle/fail
// packet back.
if pkt.src != (lnwire.ShortChannelID{}) {
l.cfg.Switch.addCircuit(&PaymentCircuit{
PaymentHash: htlc.PaymentHash,
IncomingChanID: pkt.src,
IncomingHTLCID: pkt.srcID,
OutgoingChanID: pkt.dest,
OutgoingHTLCID: index,
ErrorEncrypter: pkt.obfuscator,
})
}
htlc.ID = index
l.cfg.Peer.SendMessage(htlc)
@ -1180,6 +1196,7 @@ func (l *channelLink) processLockedInHtlcs(
case lnwallet.Settle:
settlePacket := &htlcPacket{
src: l.ShortChanID(),
srcID: pd.ParentIndex,
payHash: pd.RHash,
amount: pd.Amount,
htlc: &lnwire.UpdateFufillHTLC{
@ -1202,6 +1219,7 @@ func (l *channelLink) processLockedInHtlcs(
// continue to propagate it.
failPacket := &htlcPacket{
src: l.ShortChanID(),
srcID: pd.HtlcIndex,
payHash: pd.RHash,
amount: pd.Amount,
isObfuscated: false,
@ -1573,6 +1591,7 @@ func (l *channelLink) processLockedInHtlcs(
updatePacket := &htlcPacket{
src: l.ShortChanID(),
srcID: pd.HtlcIndex,
dest: fwdInfo.NextHop,
htlc: addMsg,
obfuscator: obfuscator,

@ -27,6 +27,14 @@ type htlcPacket struct {
// of the target link.
src lnwire.ShortChannelID
// destID is the ID of the HTLC in the destination channel. This will be set
// when forwarding a settle or fail update back to the original source.
destID uint64
// srcID is the ID of the HTLC in the source channel. This will be set when
// forwarding any HTLC update message.
srcID uint64
// amount is the value of the HTLC that is being created or modified.
amount lnwire.MilliSatoshi

@ -129,7 +129,7 @@ type Switch struct {
// circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator.
circuits *circuitMap
circuits *CircuitMap
// links is a map of channel id and channel link which manages
// this channel.
@ -167,7 +167,7 @@ type Switch struct {
func New(cfg Config) *Switch {
return &Switch{
cfg: &cfg,
circuits: newCircuitMap(),
circuits: NewCircuitMap(),
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}),
@ -481,7 +481,8 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
}
source.HandleSwitchPacket(&htlcPacket{
src: packet.src,
dest: packet.src,
destID: packet.srcID,
payHash: htlc.PaymentHash,
isObfuscated: true,
htlc: &lnwire.UpdateFailHTLC{
@ -529,7 +530,8 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
}
source.HandleSwitchPacket(&htlcPacket{
src: packet.src,
dest: packet.src,
destID: packet.srcID,
payHash: htlc.PaymentHash,
isObfuscated: true,
htlc: &lnwire.UpdateFailHTLC{
@ -544,38 +546,6 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return err
}
// If packet was forwarded from another channel link than we
// should create circuit (remember the path) in order to
// forward settle/fail packet back.
if err := s.circuits.add(newPaymentCircuit(
source.ShortChanID(),
destination.ShortChanID(),
htlc.PaymentHash,
packet.obfuscator,
)); err != nil {
failure := lnwire.NewTemporaryChannelFailure(nil)
reason, err := packet.obfuscator.EncryptFirstHop(failure)
if err != nil {
err := errors.Errorf("unable to obfuscate "+
"error: %v", err)
log.Error(err)
return err
}
source.HandleSwitchPacket(&htlcPacket{
src: packet.src,
payHash: htlc.PaymentHash,
isObfuscated: true,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
})
err = errors.Errorf("unable to add circuit: "+
"%v", err)
log.Error(err)
return err
}
// Send the packet to the destination channel link which
// manages the channel.
destination.HandleSwitchPacket(packet)
@ -585,37 +555,49 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// payment circuit by forwarding the settle msg to the channel from
// which htlc add packet was initially received.
case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC:
// Exit if we can't find and remove the active circuit to
// continue propagating the fail over.
circuit, err := s.circuits.remove(packet.payHash)
if err != nil {
err := errors.Errorf("unable to remove "+
"circuit for payment hash: %v", packet.payHash)
if packet.dest == (lnwire.ShortChannelID{}) {
// Use circuit map to find the link to forward settle/fail to.
circuit := s.circuits.LookupByHTLC(packet.src, packet.srcID)
if circuit == nil {
err := errors.Errorf("Unable to find source channel for HTLC "+
"settle/fail: channel ID = %s, HTLC ID = %d, "+
"payment hash = %x", packet.src, packet.srcID,
packet.payHash[:])
log.Error(err)
return err
}
// If this is failure than we need to obfuscate the error.
// Remove circuit since we are about to complete the HTLC.
err := s.circuits.Remove(packet.src, packet.srcID)
if err != nil {
log.Warnf("Failed to close completed onion circuit for %x: "+
"%s<->%s", packet.payHash[:], circuit.IncomingChanID,
circuit.OutgoingChanID)
} else {
log.Debugf("Closed completed onion circuit for %x: %s<->%s",
packet.payHash[:], circuit.IncomingChanID,
circuit.OutgoingChanID)
}
// Obfuscate the error message for fail updates before sending back
// through the circuit.
if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated {
htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt(
htlc.Reason,
)
htlc.Reason)
}
// Propagating settle/fail htlc back to src of add htlc packet.
source, err := s.getLinkByShortID(circuit.Src)
packet.dest = circuit.IncomingChanID
packet.destID = circuit.IncomingHTLCID
}
source, err := s.getLinkByShortID(packet.dest)
if err != nil {
err := errors.Errorf("unable to get source "+
"channel link to forward settle/fail htlc: %v",
err)
err := errors.Errorf("Unable to get source channel link to "+
"forward HTLC settle/fail: %v", err)
log.Error(err)
return err
}
log.Debugf("Closing completed onion "+
"circuit for %x: %v<->%v", packet.payHash[:],
circuit.Src, circuit.Dest)
source.HandleSwitchPacket(packet)
return nil
@ -1109,3 +1091,8 @@ func (s *Switch) numPendingPayments() int {
return l
}
// addCircuit adds a circuit to the switch's in-memory mapping.
func (s *Switch) addCircuit(circuit *PaymentCircuit) {
s.circuits.Add(circuit)
}

@ -57,6 +57,7 @@ func TestSwitchForward(t *testing.T) {
rhash := fastsha256.Sum256(preimage[:])
packet := &htlcPacket{
src: aliceChannelLink.ShortChanID(),
srcID: 0,
dest: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
@ -70,6 +71,15 @@ func TestSwitchForward(t *testing.T) {
t.Fatal(err)
}
s.addCircuit(&PaymentCircuit{
PaymentHash: packet.payHash,
IncomingChanID: packet.src,
IncomingHTLCID: 0,
OutgoingChanID: packet.dest,
OutgoingHTLCID: 0,
ErrorEncrypter: packet.obfuscator,
})
select {
case <-bobChannelLink.packets:
break
@ -146,6 +156,7 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
rhash := fastsha256.Sum256(preimage[:])
packet = &htlcPacket{
src: aliceChannelLink.ShortChanID(),
srcID: 0,
dest: bobChannelLink.ShortChanID(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
@ -234,6 +245,7 @@ func TestSwitchCancel(t *testing.T) {
rhash := fastsha256.Sum256(preimage[:])
request := &htlcPacket{
src: aliceChannelLink.ShortChanID(),
srcID: 0,
dest: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
@ -247,6 +259,15 @@ func TestSwitchCancel(t *testing.T) {
t.Fatal(err)
}
s.addCircuit(&PaymentCircuit{
PaymentHash: request.payHash,
IncomingChanID: request.src,
IncomingHTLCID: 0,
OutgoingChanID: request.dest,
OutgoingHTLCID: 0,
ErrorEncrypter: request.obfuscator,
})
select {
case <-bobChannelLink.packets:
break
@ -263,6 +284,7 @@ func TestSwitchCancel(t *testing.T) {
// request should be forwarder back to alice channel link.
request = &htlcPacket{
src: bobChannelLink.ShortChanID(),
srcID: 0,
payHash: rhash,
amount: 1,
isObfuscated: true,
@ -316,6 +338,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
rhash := fastsha256.Sum256(preimage[:])
request := &htlcPacket{
src: aliceChannelLink.ShortChanID(),
srcID: 0,
dest: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
@ -329,6 +352,15 @@ func TestSwitchAddSamePayment(t *testing.T) {
t.Fatal(err)
}
s.addCircuit(&PaymentCircuit{
PaymentHash: request.payHash,
IncomingChanID: request.src,
IncomingHTLCID: 0,
OutgoingChanID: request.dest,
OutgoingHTLCID: 0,
ErrorEncrypter: request.obfuscator,
})
select {
case <-bobChannelLink.packets:
break
@ -340,11 +372,31 @@ func TestSwitchAddSamePayment(t *testing.T) {
t.Fatal("wrong amount of circuits")
}
request = &htlcPacket{
src: aliceChannelLink.ShortChanID(),
srcID: 1,
dest: bobChannelLink.ShortChanID(),
obfuscator: newMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
// Handle the request and checks that bob channel link received it.
if err := s.forward(request); err != nil {
t.Fatal(err)
}
s.addCircuit(&PaymentCircuit{
PaymentHash: request.payHash,
IncomingChanID: request.src,
IncomingHTLCID: 1,
OutgoingChanID: request.dest,
OutgoingHTLCID: 1,
ErrorEncrypter: request.obfuscator,
})
if s.circuits.pending() != 2 {
t.Fatal("wrong amount of circuits")
}
@ -376,6 +428,15 @@ func TestSwitchAddSamePayment(t *testing.T) {
t.Fatal("wrong amount of circuits")
}
request = &htlcPacket{
src: bobChannelLink.ShortChanID(),
srcID: 1,
payHash: rhash,
amount: 1,
isObfuscated: true,
htlc: &lnwire.UpdateFailHTLC{},
}
// Handle the request and checks that payment circuit works properly.
if err := s.forward(request); err != nil {
t.Fatal(err)

@ -1,5 +1,9 @@
package lnwire
import (
"fmt"
)
// ShortChannelID represents the set of data which is needed to retrieve all
// necessary data to validate the channel existence.
type ShortChannelID struct {
@ -37,3 +41,8 @@ func (c *ShortChannelID) ToUint64() uint64 {
return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) |
(uint64(c.TxPosition)))
}
// String generates a human-readable representation of the channel ID.
func (c ShortChannelID) String() string {
return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition)
}