diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 8ccb0ded..d5b86e30 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -825,7 +825,7 @@ func (lc *LightningChannel) extractPayDescs(commitHeight uint64, return incomingHtlcs, outgoingHtlcs, nil } -// diskCommitToMemCommit converts tthe on-disk commitment format to our +// diskCommitToMemCommit converts the on-disk commitment format to our // in-memory commitment format which is needed in order to properly resume // channel operations after a restart. func (lc *LightningChannel) diskCommitToMemCommit(isLocal, isPendingCommit bool, @@ -1090,16 +1090,22 @@ type updateLog struct { // offerIndex is an index that maps the counter for offered HTLC's to // their list element within the main list.List. htlcIndex map[uint64]*list.Element + + // modifiedHtlcs is a set that keeps track of all the current modified + // htlcs. A modified HTLC is one that's present in the log, and has as + // a pending fail or settle that's attempting to consume it. + modifiedHtlcs map[uint64]struct{} } // newUpdateLog creates a new updateLog instance. func newUpdateLog(logIndex, htlcCounter uint64) *updateLog { return &updateLog{ - List: list.New(), - updateIndex: make(map[uint64]*list.Element), - htlcIndex: make(map[uint64]*list.Element), - logIndex: logIndex, - htlcCounter: htlcCounter, + List: list.New(), + updateIndex: make(map[uint64]*list.Element), + htlcIndex: make(map[uint64]*list.Element), + logIndex: logIndex, + htlcCounter: htlcCounter, + modifiedHtlcs: make(map[uint64]struct{}), } } @@ -1158,6 +1164,22 @@ func (u *updateLog) removeHtlc(i uint64) { entry := u.htlcIndex[i] u.Remove(entry) delete(u.htlcIndex, i) + + delete(u.modifiedHtlcs, i) +} + +// htlcHasModification returns true if the HTLC identified by the passed index +// has a pending modification within the log. +func (u *updateLog) htlcHasModification(i uint64) bool { + _, o := u.modifiedHtlcs[i] + return o +} + +// markHtlcModified marks an HTLC as modified based on its HTLC index. After a +// call to this method, htlcHasModification will return true until the HTLC is +// removed. +func (u *updateLog) markHtlcModified(i uint64) { + u.modifiedHtlcs[i] = struct{}{} } // compactLogs performs garbage collection within the log removing HTLCs which @@ -1509,10 +1531,12 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, if !isDustRemote { theirP2WSH, theirWitnessScript, err := genHtlcScript( false, false, wireMsg.Expiry, wireMsg.PaymentHash, - remoteCommitKeys) + remoteCommitKeys, + ) if err != nil { return nil, err } + pd.theirPkScript = theirP2WSH pd.theirWitnessScript = theirWitnessScript } @@ -1571,7 +1595,7 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, } // restoreCommitState will restore the local commitment chain and updateLog -// state to a consistent in-memory representation of the passed dis commitment. +// state to a consistent in-memory representation of the passed disk commitment. // This method is to be used upon reconnection to our channel counter party. // Once the connection has been established, we'll prepare our in memory state // to re-sync states with the remote party, and also verify/extend new proposed @@ -1749,9 +1773,12 @@ func (lc *LightningChannel) restoreStateLogs( "%v vs %v", payDesc.HtlcIndex, lc.localUpdateLog.htlcCounter)) } + lc.localUpdateLog.appendHtlc(payDesc) } else { lc.localUpdateLog.appendUpdate(payDesc) + + lc.remoteUpdateLog.markHtlcModified(payDesc.ParentIndex) } } @@ -4353,6 +4380,14 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte, lc.ShortChanID()) } + // Now that we know the HTLC exists, before checking to see if the + // preimage matches, we'll ensure that we haven't already attempted to + // modify the HTLC. + if lc.remoteUpdateLog.htlcHasModification(htlcIndex) { + return fmt.Errorf("HTLC with ID %d has already been settled", + htlcIndex) + } + if htlc.RHash != sha256.Sum256(preimage[:]) { return fmt.Errorf("Invalid payment preimage %x for hash %x", preimage[:], htlc.RHash[:]) @@ -4371,6 +4406,11 @@ func (lc *LightningChannel) SettleHTLC(preimage [32]byte, lc.localUpdateLog.appendUpdate(pd) + // With the settle added to our local log, we'll now mark the HTLC as + // modified to prevent ourselves from accidentally attempting a + // duplicate settle. + lc.remoteUpdateLog.markHtlcModified(htlcIndex) + return nil } @@ -4388,6 +4428,14 @@ func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, htlcIndex uint6 lc.ShortChanID()) } + // Now that we know the HTLC exists, before checking to see if the + // preimage matches, we'll ensure that they haven't already attempted + // to modify the HTLC. + if lc.localUpdateLog.htlcHasModification(htlcIndex) { + return fmt.Errorf("HTLC with ID %d has already been settled", + htlcIndex) + } + if htlc.RHash != sha256.Sum256(preimage[:]) { return fmt.Errorf("Invalid payment preimage %x for hash %x", preimage[:], htlc.RHash[:]) @@ -4403,6 +4451,12 @@ func (lc *LightningChannel) ReceiveHTLCSettle(preimage [32]byte, htlcIndex uint6 } lc.remoteUpdateLog.appendUpdate(pd) + + // With the settle added to the remote log, we'll now mark the HTLC as + // modified to prevent the remote party from accidentally attempting a + // duplicate settle. + lc.localUpdateLog.markHtlcModified(htlcIndex) + return nil } @@ -4442,6 +4496,13 @@ func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte, lc.ShortChanID()) } + // Now that we know the HTLC exists, we'll ensure that we haven't + // already attempted to fail the HTLC. + if lc.remoteUpdateLog.htlcHasModification(htlcIndex) { + return fmt.Errorf("HTLC with ID %d has already been failed", + htlcIndex) + } + pd := &PaymentDescriptor{ Amount: htlc.Amount, RHash: htlc.RHash, @@ -4456,6 +4517,11 @@ func (lc *LightningChannel) FailHTLC(htlcIndex uint64, reason []byte, lc.localUpdateLog.appendUpdate(pd) + // With the fail added to the remote log, we'll now mark the HTLC as + // modified to prevent ourselves from accidentally attempting a + // duplicate fail. + lc.remoteUpdateLog.markHtlcModified(htlcIndex) + return nil } @@ -4482,6 +4548,13 @@ func (lc *LightningChannel) MalformedFailHTLC(htlcIndex uint64, lc.ShortChanID()) } + // Now that we know the HTLC exists, we'll ensure that we haven't + // already attempted to fail the HTLC. + if lc.remoteUpdateLog.htlcHasModification(htlcIndex) { + return fmt.Errorf("HTLC with ID %d has already been failed", + htlcIndex) + } + pd := &PaymentDescriptor{ Amount: htlc.Amount, RHash: htlc.RHash, @@ -4495,6 +4568,11 @@ func (lc *LightningChannel) MalformedFailHTLC(htlcIndex uint64, lc.localUpdateLog.appendUpdate(pd) + // With the fail added to the remote log, we'll now mark the HTLC as + // modified to prevent ourselves from accidentally attempting a + // duplicate fail. + lc.remoteUpdateLog.markHtlcModified(htlcIndex) + return nil } @@ -4515,6 +4593,13 @@ func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte, lc.ShortChanID()) } + // Now that we know the HTLC exists, we'll ensure that they haven't + // already attempted to fail the HTLC. + if lc.localUpdateLog.htlcHasModification(htlcIndex) { + return fmt.Errorf("HTLC with ID %d has already been failed", + htlcIndex) + } + pd := &PaymentDescriptor{ Amount: htlc.Amount, RHash: htlc.RHash, @@ -4526,6 +4611,11 @@ func (lc *LightningChannel) ReceiveFailHTLC(htlcIndex uint64, reason []byte, lc.remoteUpdateLog.appendUpdate(pd) + // With the fail added to the remote log, we'll now mark the HTLC as + // modified to prevent ourselves from accidentally attempting a + // duplicate fail. + lc.localUpdateLog.markHtlcModified(htlcIndex) + return nil } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index f0b47d00..29fe4c5e 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5342,3 +5342,166 @@ func TestChannelRestoreUpdateLogsFailedHTLC(t *testing.T) { assertInLogs(t, aliceChannel, 0, 0, 0, 0) restoreAndAssert(t, aliceChannel, 0, 0, 0, 0) } + +// TestDuplicateFailRejection tests that if either party attempts to fail an +// HTLC twice, then we'll reject the second fail attempt. +func TestDuplicateFailRejection(t *testing.T) { + t.Parallel() + + aliceChannel, bobChannel, cleanUp, err := CreateTestChannels() + if err != nil { + t.Fatalf("unable to create test channels: %v", err) + } + defer cleanUp() + + // First, we'll add an HTLC from Alice to Bob, and lock it in for both + // parties. + htlcAmount := lnwire.NewMSatFromSatoshis(20000) + htlcAlice, _ := createHTLC(0, htlcAmount) + if _, err := aliceChannel.AddHTLC(htlcAlice, nil); err != nil { + t.Fatalf("alice unable to add htlc: %v", err) + } + _, err = bobChannel.ReceiveHTLC(htlcAlice) + if err != nil { + t.Fatalf("unable to recv htlc: %v", err) + } + + if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + t.Fatalf("unable to complete state update: %v", err) + } + + // With the HTLC locked in, we'll now have Bob fail the HTLC back to + // Alice. + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + if err != nil { + t.Fatalf("unable to cancel HTLC: %v", err) + } + if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err != nil { + t.Fatalf("unable to recv htlc cancel: %v", err) + } + + // If we attempt to fail it AGAIN, then both sides should reject this + // second failure attempt. + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + if err == nil { + t.Fatalf("duplicate HTLC failure attempt should have failed") + } + if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err == nil { + t.Fatalf("duplicate HTLC failure attempt should have failed") + } + + // We'll now have Bob sign a new commitment to lock in the HTLC fail + // for Alice. + _, _, err = bobChannel.SignNextCommitment() + if err != nil { + t.Fatalf("unable to sign commit: %v", err) + } + + // We'll now force a restart for Bob and Alice, so we can test the + // persistence related portion of this assertion. + bobChannel, err = restartChannel(bobChannel) + if err != nil { + t.Fatalf("unable to restart channel: %v", err) + } + defer bobChannel.Stop() + aliceChannel, err = restartChannel(aliceChannel) + if err != nil { + t.Fatalf("unable to restart channel: %v", err) + } + defer aliceChannel.Stop() + + // If we try to fail the same HTLC again, then we should get an error. + err = bobChannel.FailHTLC(0, []byte("failreason"), nil, nil, nil) + if err == nil { + t.Fatalf("duplicate HTLC failure attempt should have failed") + } + + // Alice on the other hand should accept the failure again, as she + // dropped all items in the logs which weren't committed. + if err := aliceChannel.ReceiveFailHTLC(0, []byte("bad")); err != nil { + t.Fatalf("unable to recv htlc cancel: %v", err) + } +} + +// TestDuplicateSettleRejection tests that if either party attempts to settle +// an HTLC twice, then we'll reject the second settle attempt. +func TestDuplicateSettleRejection(t *testing.T) { + t.Parallel() + + aliceChannel, bobChannel, cleanUp, err := CreateTestChannels() + if err != nil { + t.Fatalf("unable to create test channels: %v", err) + } + defer cleanUp() + + // First, we'll add an HTLC from Alice to Bob, and lock it in for both + // parties. + htlcAmount := lnwire.NewMSatFromSatoshis(20000) + htlcAlice, alicePreimage := createHTLC(0, htlcAmount) + if _, err := aliceChannel.AddHTLC(htlcAlice, nil); err != nil { + t.Fatalf("alice unable to add htlc: %v", err) + } + _, err = bobChannel.ReceiveHTLC(htlcAlice) + if err != nil { + t.Fatalf("unable to recv htlc: %v", err) + } + + if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + t.Fatalf("unable to complete state update: %v", err) + } + + // With the HTLC locked in, we'll now have Bob settle the HTLC back to + // Alice. + err = bobChannel.SettleHTLC(alicePreimage, uint64(0), nil, nil, nil) + if err != nil { + t.Fatalf("unable to cancel HTLC: %v", err) + } + err = aliceChannel.ReceiveHTLCSettle(alicePreimage, uint64(0)) + if err != nil { + t.Fatalf("unable to recv htlc cancel: %v", err) + } + + // If we attempt to fail it AGAIN, then both sides should reject this + // second failure attempt. + err = bobChannel.SettleHTLC(alicePreimage, uint64(0), nil, nil, nil) + if err == nil { + t.Fatalf("duplicate HTLC failure attempt should have failed") + } + err = aliceChannel.ReceiveHTLCSettle(alicePreimage, uint64(0)) + if err == nil { + t.Fatalf("duplicate HTLC failure attempt should have failed") + } + + // We'll now have Bob sign a new commitment to lock in the HTLC fail + // for Alice. + _, _, err = bobChannel.SignNextCommitment() + if err != nil { + t.Fatalf("unable to sign commit: %v", err) + } + + // We'll now force a restart for Bob and Alice, so we can test the + // persistence related portion of this assertion. + bobChannel, err = restartChannel(bobChannel) + if err != nil { + t.Fatalf("unable to restart channel: %v", err) + } + defer bobChannel.Stop() + aliceChannel, err = restartChannel(aliceChannel) + if err != nil { + t.Fatalf("unable to restart channel: %v", err) + } + defer aliceChannel.Stop() + + // If we try to fail the same HTLC again, then we should get an error. + err = bobChannel.SettleHTLC(alicePreimage, uint64(0), nil, nil, nil) + if err == nil { + t.Fatalf("duplicate HTLC failure attempt should have failed") + } + + // Alice on the other hand should accept the failure again, as she + // dropped all items in the logs which weren't committed. + err = aliceChannel.ReceiveHTLCSettle(alicePreimage, uint64(0)) + if err != nil { + t.Fatalf("unable to recv htlc cancel: %v", err) + } +}