diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 69d10d1c..dcf96d49 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -2,7 +2,9 @@ package lnwallet import ( "bytes" + "container/list" "crypto/sha256" + "fmt" "reflect" "runtime" @@ -4906,3 +4908,203 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { "instead %v were", len(breachRet.HtlcRetributions)) } } + +// compareHtlcs compares two PaymentDescriptors. +func compareHtlcs(htlc1, htlc2 *PaymentDescriptor) error { + if htlc1.LogIndex != htlc2.LogIndex { + return fmt.Errorf("htlc log index did not match") + } + if htlc1.HtlcIndex != htlc2.HtlcIndex { + return fmt.Errorf("htlc index did not match") + } + if htlc1.ParentIndex != htlc2.ParentIndex { + return fmt.Errorf("htlc parent index did not match") + } + + if htlc1.RHash != htlc2.RHash { + return fmt.Errorf("htlc rhash did not match") + } + return nil +} + +// compareIndexes is a helper method to compare two index maps. +func compareIndexes(a, b map[uint64]*list.Element) error { + for k1, e1 := range a { + e2, ok := b[k1] + if !ok { + return fmt.Errorf("element with key %d "+ + "not found in b", k1) + } + htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor) + if err := compareHtlcs(htlc1, htlc2); err != nil { + return err + } + } + + for k1, e1 := range b { + e2, ok := a[k1] + if !ok { + return fmt.Errorf("element with key %d not "+ + "found in a", k1) + } + htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor) + if err := compareHtlcs(htlc1, htlc2); err != nil { + return err + } + } + + return nil +} + +// compareLogs is a helper method to compare two updateLogs. +func compareLogs(a, b *updateLog) error { + if a.logIndex != b.logIndex { + return fmt.Errorf("log indexes don't match: %d vs %d", + a.logIndex, b.logIndex) + } + + if a.htlcCounter != b.htlcCounter { + return fmt.Errorf("htlc counters don't match: %d vs %d", + a.htlcCounter, b.htlcCounter) + } + + if err := compareIndexes(a.updateIndex, b.updateIndex); err != nil { + return fmt.Errorf("update indexes don't match: %v", err) + } + if err := compareIndexes(a.htlcIndex, b.htlcIndex); err != nil { + return fmt.Errorf("htlc indexes don't match: %v", err) + } + + if a.Len() != b.Len() { + return fmt.Errorf("list lengths not equal: %d vs %d", + a.Len(), b.Len()) + } + + e1, e2 := a.Front(), b.Front() + for ; e1 != nil; e1, e2 = e1.Next(), e2.Next() { + htlc1, htlc2 := e1.Value.(*PaymentDescriptor), e2.Value.(*PaymentDescriptor) + if err := compareHtlcs(htlc1, htlc2); err != nil { + return err + } + } + + return nil +} + +// TestChannelRestoreUpdateLogs makes sure we are able to properly restore the +// update logs in the case where a different number of HTLCs are locked in on +// the local, remote and pending remote commitment. +func TestChannelRestoreUpdateLogs(t *testing.T) { + t.Parallel() + + aliceChannel, bobChannel, cleanUp, err := CreateTestChannels() + if err != nil { + t.Fatalf("unable to create test channels: %v", err) + } + defer cleanUp() + + // First, we'll add an HTLC from Alice to Bob, which we will lock in on + // Bob's commit, but not on Alice's. + htlcAmount := lnwire.NewMSatFromSatoshis(20000) + htlcAlice, _ := createHTLC(0, htlcAmount) + if _, err := aliceChannel.AddHTLC(htlcAlice, nil); err != nil { + t.Fatalf("alice unable to add htlc: %v", err) + } + if _, err := bobChannel.ReceiveHTLC(htlcAlice); err != nil { + t.Fatalf("bob unable to recv add htlc: %v", err) + } + + // Let Alice sign a new state, which will include the HTLC just sent. + aliceSig, aliceHtlcSigs, err := aliceChannel.SignNextCommitment() + if err != nil { + t.Fatalf("unable to sign commitment: %v", err) + } + + // Bob receives this commitment signature, and revokes his old state. + err = bobChannel.ReceiveNewCommitment(aliceSig, aliceHtlcSigs) + if err != nil { + t.Fatalf("unable to receive commitment: %v", err) + } + bobRevocation, _, err := bobChannel.RevokeCurrentCommitment() + if err != nil { + t.Fatalf("unable to revoke commitment: %v", err) + } + + // When Alice now receives this revocation, she will advance her remote + // commitment chain to the commitment which includes the HTLC just + // sent. However her local commitment chain still won't include the + // state with the HTLC, since she hasn't received a new commitment + // signature from Bob yet. + _, _, _, err = aliceChannel.ReceiveRevocation(bobRevocation) + if err != nil { + t.Fatalf("unable to recive revocation: %v", err) + } + + // Now make Alice send and sign an additional HTLC. We don't let Bob + // receive it. We do this since we want to check that update logs are + // restored properly below, and we'll only restore updates that have + // been ACKed. + htlcAlice, _ = createHTLC(1, htlcAmount) + if _, err := aliceChannel.AddHTLC(htlcAlice, nil); err != nil { + t.Fatalf("alice unable to add htlc: %v", err) + } + + // Send the signature covering the HTLC. This is okay, since the local + // and remote commit chains are updated in an async fashion. Since the + // remote chain was updated with the latest state (since Bob sent the + // revocation earlier) we can keep advancing the remote commit chain. + aliceSig, aliceHtlcSigs, err = aliceChannel.SignNextCommitment() + if err != nil { + t.Fatalf("unable to sign commitment: %v", err) + } + + // After Alice has signed this commitment, her local commitment will + // contain no HTLCs, her remote commitment will contain an HTLC with + // index 0, and the pending remote commitment (a signed remote + // commitment which is not AKCed yet) will contain an additional HTLC + // with index 1. + + // We now re-create the channels, mimicking a restart. This should sync + // the update logs up to the correct state set up above. + newAliceChannel, err := NewLightningChannel( + aliceChannel.Signer, nil, aliceChannel.channelState, + ) + if err != nil { + t.Fatalf("unable to create new channel: %v", err) + } + defer newAliceChannel.Stop() + + newBobChannel, err := NewLightningChannel( + bobChannel.Signer, nil, bobChannel.channelState, + ) + if err != nil { + t.Fatalf("unable to create new channel: %v", err) + } + defer newBobChannel.Stop() + + // compare all the logs between the old and new channels, to make sure + // they all got restored properly. + err = compareLogs(aliceChannel.localUpdateLog, + newAliceChannel.localUpdateLog) + if err != nil { + t.Fatalf("alice local log not restored: %v", err) + } + + err = compareLogs(aliceChannel.remoteUpdateLog, + newAliceChannel.remoteUpdateLog) + if err != nil { + t.Fatalf("alice remote log not restored: %v", err) + } + + err = compareLogs(bobChannel.localUpdateLog, + newBobChannel.localUpdateLog) + if err != nil { + t.Fatalf("bob local log not restored: %v", err) + } + + err = compareLogs(bobChannel.remoteUpdateLog, + newBobChannel.remoteUpdateLog) + if err != nil { + t.Fatalf("bob remote log not restored: %v", err) + } +}