diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 54d42b52..0d13f6ba 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -152,7 +152,6 @@ type PaymentDescriptor struct { // isForwarded denotes if an incoming HTLC has been forwarded to any // possible upstream peers in the route. isForwarded bool - settled bool // pkScript is the raw public key script that encodes the redemption // rules for this particular HTLC. This field will only be populated @@ -409,6 +408,13 @@ type LightningChannel struct { ourLogIndex map[uint32]*list.Element theirLogIndex map[uint32]*list.Element + // rHashMap is a map with PaymentHashes pointing to their respective + // PaymentDescriptors. We insert *PaymentDescriptors whenever we + // receive HTLCs. When a state transition happens (settling or + // canceling the HTLC), rHashMap will provide an efficient + // way to lookup the original PaymentDescriptor. + rHashMap map[PaymentHash][]*PaymentDescriptor + LocalDeliveryScript []byte RemoteDeliveryScript []byte @@ -465,6 +471,7 @@ func NewLightningChannel(signer Signer, events chainntnfs.ChainNotifier, theirUpdateLog: list.New(), ourLogIndex: make(map[uint32]*list.Element), theirLogIndex: make(map[uint32]*list.Element), + rHashMap: make(map[PaymentHash][]*PaymentDescriptor), Capacity: state.Capacity, LocalDeliveryScript: state.OurDeliveryScript, RemoteDeliveryScript: state.TheirDeliveryScript, @@ -830,6 +837,7 @@ func (lc *LightningChannel) restoreStateLogs() error { } else { pd.Index = theirCounter lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd) + lc.rHashMap[pd.RHash] = append(lc.rHashMap[pd.RHash], pd) theirCounter++ } @@ -1675,6 +1683,8 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.HTLCAddRequest) (uint32, er lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd) lc.theirLogCounter++ + lc.rHashMap[pd.RHash] = append(lc.rHashMap[pd.RHash], pd) + return pd.Index, nil } @@ -1686,25 +1696,13 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) { lc.Lock() defer lc.Unlock() - var targetHTLC *PaymentDescriptor - - // TODO(roasbeef): optimize paymentHash := fastsha256.Sum256(preimage[:]) - for e := lc.theirUpdateLog.Front(); e != nil; e = e.Next() { - htlc := e.Value.(*PaymentDescriptor) - if htlc.EntryType != Add { - continue - } - if !htlc.settled && bytes.Equal(htlc.RHash[:], paymentHash[:]) { - htlc.settled = true - targetHTLC = htlc - break - } - } - if targetHTLC == nil { + targetHTLCs, ok := lc.rHashMap[paymentHash] + if !ok { return 0, fmt.Errorf("invalid payment hash") } + targetHTLC := targetHTLCs[0] pd := &PaymentDescriptor{ Amount: targetHTLC.Amount, @@ -1717,6 +1715,12 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) { lc.ourUpdateLog.PushBack(pd) lc.ourLogCounter++ + lc.rHashMap[paymentHash][0] = nil + lc.rHashMap[paymentHash] = lc.rHashMap[paymentHash][1:] + if len(lc.rHashMap[paymentHash]) == 0 { + delete(lc.rHashMap, paymentHash) + } + return targetHTLC.Index, nil } @@ -1761,21 +1765,12 @@ func (lc *LightningChannel) CancelHTLC(rHash [32]byte) (uint32, error) { lc.Lock() defer lc.Unlock() - var addEntry *PaymentDescriptor - for e := lc.theirUpdateLog.Front(); e != nil; e = e.Next() { - htlc := e.Value.(*PaymentDescriptor) - if htlc.EntryType != Add { - continue - } - - if !htlc.settled && bytes.Equal(htlc.RHash[:], rHash[:]) { - addEntry = htlc - break - } - } - if addEntry == nil { + addEntries, ok := lc.rHashMap[rHash] + if !ok { return 0, fmt.Errorf("unable to find HTLC to cancel") } + addEntry := addEntries[0] + pd := &PaymentDescriptor{ Amount: addEntry.Amount, RHash: addEntry.RHash, @@ -1787,6 +1782,12 @@ func (lc *LightningChannel) CancelHTLC(rHash [32]byte) (uint32, error) { lc.ourUpdateLog.PushBack(pd) lc.ourLogCounter++ + lc.rHashMap[rHash][0] = nil + lc.rHashMap[rHash] = lc.rHashMap[rHash][1:] + if len(lc.rHashMap[rHash]) == 0 { + delete(lc.rHashMap, rHash) + } + return addEntry.Index, nil }