lnwallet: optimize PaymentDescriptor lookup on HTLC state transitions

rHashMap is used to store the PaymentDescriptor belonging to a received
HTLC's revocation hash. This improves the efficiency of looking up
PaymentDescriptors from their RHash whenever we want to settle or cancel
that HTLC.
This commit is contained in:
Christopher Jämthagen 2017-01-16 12:57:26 +01:00 committed by Olaoluwa Osuntokun
parent 59615b3cb2
commit a2403d9c07

@ -152,7 +152,6 @@ type PaymentDescriptor struct {
// isForwarded denotes if an incoming HTLC has been forwarded to any // isForwarded denotes if an incoming HTLC has been forwarded to any
// possible upstream peers in the route. // possible upstream peers in the route.
isForwarded bool isForwarded bool
settled bool
// pkScript is the raw public key script that encodes the redemption // pkScript is the raw public key script that encodes the redemption
// rules for this particular HTLC. This field will only be populated // rules for this particular HTLC. This field will only be populated
@ -409,6 +408,13 @@ type LightningChannel struct {
ourLogIndex map[uint32]*list.Element ourLogIndex map[uint32]*list.Element
theirLogIndex map[uint32]*list.Element theirLogIndex map[uint32]*list.Element
// rHashMap is a map with PaymentHashes pointing to their respective
// PaymentDescriptors. We insert *PaymentDescriptors whenever we
// receive HTLCs. When a state transition happens (settling or
// canceling the HTLC), rHashMap will provide an efficient
// way to lookup the original PaymentDescriptor.
rHashMap map[PaymentHash][]*PaymentDescriptor
LocalDeliveryScript []byte LocalDeliveryScript []byte
RemoteDeliveryScript []byte RemoteDeliveryScript []byte
@ -465,6 +471,7 @@ func NewLightningChannel(signer Signer, events chainntnfs.ChainNotifier,
theirUpdateLog: list.New(), theirUpdateLog: list.New(),
ourLogIndex: make(map[uint32]*list.Element), ourLogIndex: make(map[uint32]*list.Element),
theirLogIndex: make(map[uint32]*list.Element), theirLogIndex: make(map[uint32]*list.Element),
rHashMap: make(map[PaymentHash][]*PaymentDescriptor),
Capacity: state.Capacity, Capacity: state.Capacity,
LocalDeliveryScript: state.OurDeliveryScript, LocalDeliveryScript: state.OurDeliveryScript,
RemoteDeliveryScript: state.TheirDeliveryScript, RemoteDeliveryScript: state.TheirDeliveryScript,
@ -830,6 +837,7 @@ func (lc *LightningChannel) restoreStateLogs() error {
} else { } else {
pd.Index = theirCounter pd.Index = theirCounter
lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd) lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd)
lc.rHashMap[pd.RHash] = append(lc.rHashMap[pd.RHash], pd)
theirCounter++ theirCounter++
} }
@ -1675,6 +1683,8 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.HTLCAddRequest) (uint32, er
lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd) lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd)
lc.theirLogCounter++ lc.theirLogCounter++
lc.rHashMap[pd.RHash] = append(lc.rHashMap[pd.RHash], pd)
return pd.Index, nil return pd.Index, nil
} }
@ -1686,25 +1696,13 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
var targetHTLC *PaymentDescriptor
// TODO(roasbeef): optimize
paymentHash := fastsha256.Sum256(preimage[:]) paymentHash := fastsha256.Sum256(preimage[:])
for e := lc.theirUpdateLog.Front(); e != nil; e = e.Next() {
htlc := e.Value.(*PaymentDescriptor)
if htlc.EntryType != Add {
continue
}
if !htlc.settled && bytes.Equal(htlc.RHash[:], paymentHash[:]) { targetHTLCs, ok := lc.rHashMap[paymentHash]
htlc.settled = true if !ok {
targetHTLC = htlc
break
}
}
if targetHTLC == nil {
return 0, fmt.Errorf("invalid payment hash") return 0, fmt.Errorf("invalid payment hash")
} }
targetHTLC := targetHTLCs[0]
pd := &PaymentDescriptor{ pd := &PaymentDescriptor{
Amount: targetHTLC.Amount, Amount: targetHTLC.Amount,
@ -1717,6 +1715,12 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) {
lc.ourUpdateLog.PushBack(pd) lc.ourUpdateLog.PushBack(pd)
lc.ourLogCounter++ lc.ourLogCounter++
lc.rHashMap[paymentHash][0] = nil
lc.rHashMap[paymentHash] = lc.rHashMap[paymentHash][1:]
if len(lc.rHashMap[paymentHash]) == 0 {
delete(lc.rHashMap, paymentHash)
}
return targetHTLC.Index, nil return targetHTLC.Index, nil
} }
@ -1761,21 +1765,12 @@ func (lc *LightningChannel) CancelHTLC(rHash [32]byte) (uint32, error) {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
var addEntry *PaymentDescriptor addEntries, ok := lc.rHashMap[rHash]
for e := lc.theirUpdateLog.Front(); e != nil; e = e.Next() { if !ok {
htlc := e.Value.(*PaymentDescriptor)
if htlc.EntryType != Add {
continue
}
if !htlc.settled && bytes.Equal(htlc.RHash[:], rHash[:]) {
addEntry = htlc
break
}
}
if addEntry == nil {
return 0, fmt.Errorf("unable to find HTLC to cancel") return 0, fmt.Errorf("unable to find HTLC to cancel")
} }
addEntry := addEntries[0]
pd := &PaymentDescriptor{ pd := &PaymentDescriptor{
Amount: addEntry.Amount, Amount: addEntry.Amount,
RHash: addEntry.RHash, RHash: addEntry.RHash,
@ -1787,6 +1782,12 @@ func (lc *LightningChannel) CancelHTLC(rHash [32]byte) (uint32, error) {
lc.ourUpdateLog.PushBack(pd) lc.ourUpdateLog.PushBack(pd)
lc.ourLogCounter++ lc.ourLogCounter++
lc.rHashMap[rHash][0] = nil
lc.rHashMap[rHash] = lc.rHashMap[rHash][1:]
if len(lc.rHashMap[rHash]) == 0 {
delete(lc.rHashMap, rHash)
}
return addEntry.Index, nil return addEntry.Index, nil
} }