diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 03ddff35..96b38b11 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1646,7 +1646,7 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) { lc.Lock() defer lc.Unlock() - var targetHTLC *list.Element + var targetHTLC *PaymentDescriptor // TODO(roasbeef): optimize paymentHash := fastsha256.Sum256(preimage[:]) @@ -1658,7 +1658,7 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) { if !htlc.settled && bytes.Equal(htlc.RHash[:], paymentHash[:]) { htlc.settled = true - targetHTLC = e + targetHTLC = htlc break } } @@ -1666,20 +1666,18 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte) (uint32, error) { return 0, fmt.Errorf("invalid payment hash") } - parentPd := targetHTLC.Value.(*PaymentDescriptor) - pd := &PaymentDescriptor{ - Amount: parentPd.Amount, + Amount: targetHTLC.Amount, RPreimage: preimage, Index: lc.ourLogCounter, - ParentIndex: parentPd.Index, + ParentIndex: targetHTLC.Index, EntryType: Settle, } lc.ourUpdateLog.PushBack(pd) lc.ourLogCounter++ - return targetHTLC.Value.(*PaymentDescriptor).Index, nil + return targetHTLC.Index, nil } // ReceiveHTLCSettle attempts to settle an existing outgoing HTLC indexed by an @@ -1715,32 +1713,41 @@ func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, logIndex uint32 return nil } -// CancelHTLC attempts to cancel a targeted HTLC by its log index, inserting an -// entry which will remove the target log entry within the next commitment +// CancelHTLC attempts to cancel a targeted HTLC by its payment hash, inserting +// an entry which will remove the target log entry within the next commitment // update. This method is intended to be called in order to cancel in // _incoming_ HTLC. -func (lc *LightningChannel) CancelHTLC(logIndex uint32) error { +func (lc *LightningChannel) CancelHTLC(rHash [32]byte) (uint32, error) { lc.Lock() defer lc.Unlock() - addEntry, ok := lc.theirLogIndex[logIndex] - if !ok { - return fmt.Errorf("unable to find HTLC to cancel") + var addEntry *PaymentDescriptor + for e := lc.theirUpdateLog.Front(); e != nil; e = e.Next() { + htlc := e.Value.(*PaymentDescriptor) + if htlc.EntryType != Add { + continue + } + + if !htlc.settled && bytes.Equal(htlc.RHash[:], rHash[:]) { + addEntry = htlc + break + } + } + if addEntry == nil { + return 0, fmt.Errorf("unable to find HTLC to cancel") } - htlc := addEntry.Value.(*PaymentDescriptor) - pd := &PaymentDescriptor{ - Amount: htlc.Amount, + Amount: addEntry.Amount, Index: lc.ourLogCounter, - ParentIndex: htlc.Index, + ParentIndex: addEntry.Index, EntryType: Cancel, } lc.ourUpdateLog.PushBack(pd) lc.ourLogCounter++ - return nil + return addEntry.Index, nil } // ReceiveCancelHTLC attempts to cancel a targeted HTLC by its log index, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index d10a89bd..b21d81ed 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -1180,12 +1180,12 @@ func TestCancelHTLC(t *testing.T) { Amount: htlcAmt, Expiry: 10, } + paymentHash := htlc.RedemptionHashes[0] if _, err := aliceChannel.AddHTLC(htlc); err != nil { t.Fatalf("unable to add alice htlc: %v", err) } - bobHtlcIndex, err := bobChannel.ReceiveHTLC(htlc) - if err != nil { + if _, err := bobChannel.ReceiveHTLC(htlc); err != nil { t.Fatalf("unable to add bob htlc: %v", err) } if err := forceStateTransition(aliceChannel, bobChannel); err != nil { @@ -1202,10 +1202,11 @@ func TestCancelHTLC(t *testing.T) { // Now, with the HTLC committed on both sides, trigger a cancellation // from Bob to Alice, removing the HTLC. - if err := bobChannel.CancelHTLC(bobHtlcIndex); err != nil { + htlcCancelIndex, err := bobChannel.CancelHTLC(paymentHash) + if err != nil { t.Fatalf("unable to cancel HTLC: %v", err) } - if err := aliceChannel.ReceiveCancelHTLC(bobHtlcIndex); err != nil { + if err := aliceChannel.ReceiveCancelHTLC(htlcCancelIndex); err != nil { t.Fatalf("unable to recv htlc cancel: %v", err) }