diff --git a/lnwallet/channel.go b/lnwallet/channel.go index e6900f85..524b9a37 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -183,6 +183,47 @@ type commitment struct { // indexes. ourBalance btcutil.Amount theirBalance btcutil.Amount + + // htlcs is the set of HTLC's which remain uncleared within this + // commitment. + outgoingHTLCs []*PaymentDescriptor + incomingHTLCs []*PaymentDescriptor +} + +// toChannelDelta converts the target commitment into a format suitable to be +// written to disk after an accepted state transition. +// TODO(roasbeef): properly fill in refund timeouts +func (c *commitment) toChannelDelta() (*channeldb.ChannelDelta, error) { + numHtlcs := len(c.outgoingHTLCs) + len(c.incomingHTLCs) + delta := &channeldb.ChannelDelta{ + LocalBalance: c.ourBalance, + RemoteBalance: c.theirBalance, + UpdateNum: uint32(c.height), + Htlcs: make([]*channeldb.HTLC, 0, numHtlcs), + } + + for _, htlc := range c.outgoingHTLCs { + h := &channeldb.HTLC{ + Incoming: false, + Amt: htlc.Amount, + RHash: htlc.RHash, + RefundTimeout: htlc.Timeout, + RevocationDelay: 0, + } + delta.Htlcs = append(delta.Htlcs, h) + } + for _, htlc := range c.incomingHTLCs { + h := &channeldb.HTLC{ + Incoming: true, + Amt: htlc.Amount, + RHash: htlc.RHash, + RefundTimeout: htlc.Timeout, + RevocationDelay: 0, + } + delta.Htlcs = append(delta.Htlcs, h) + } + + return delta, nil } // commitmentChain represents a chain of unrevoked commitments. The tail of the @@ -386,6 +427,8 @@ func NewLightningChannel(signer Signer, wallet *LightningWallet, // Initialize both of our chains the current un-revoked commitment for // each side. + // TODO(roasbeef): add chnneldb.RevocationLogTail method, then init + // their commitment from that initialCommitment := &commitment{ height: lc.currentHeight, ourBalance: state.OurBalance, @@ -396,6 +439,12 @@ func NewLightningChannel(signer Signer, wallet *LightningWallet, lc.localCommitChain.addCommitment(initialCommitment) lc.remoteCommitChain.addCommitment(initialCommitment) + // If we're restarting from a channel with history, then restore the + // update in-memory update logs to that of the prior state. + if lc.currentHeight != 0 { + lc.restoreStateLogs() + } + // TODO(roasbeef): do a NotifySpent for the funding input, and // NotifyReceived for all commitment outputs. @@ -420,6 +469,54 @@ func NewLightningChannel(signer Signer, wallet *LightningWallet, return lc, nil } +// restoreStateLogs runs through the current locked-in HTLC's from the point of +// view of the channel and insert corresponding log entries (both local and +// remote) for each HTLC read from disk. This method is required sync the +// in-memory state of the state machine with that read from persistent storage. +func (lc *LightningChannel) restoreStateLogs() error { + var pastHeight uint64 + if lc.currentHeight > 0 { + pastHeight = lc.currentHeight - 1 + } + + var ourCounter, theirCounter uint32 + for _, htlc := range lc.channelState.Htlcs { + // TODO(roasbeef): set isForwarded to false for all? need to + // persist state w.r.t to if forwarded or not, or can + // inadvertenly trigger replays + pd := &PaymentDescriptor{ + RHash: htlc.RHash, + Timeout: htlc.RefundTimeout, + Amount: htlc.Amt, + EntryType: Add, + addCommitHeightRemote: pastHeight, + addCommitHeightLocal: pastHeight, + } + + if !htlc.Incoming { + pd.Index = ourCounter + lc.ourLogIndex[pd.Index] = lc.ourUpdateLog.PushBack(pd) + + ourCounter++ + } else { + pd.Index = theirCounter + lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd) + + theirCounter++ + } + } + + lc.ourLogCounter = ourCounter + lc.theirLogCounter = theirCounter + + lc.localCommitChain.tail().ourMessageIndex = ourCounter + lc.localCommitChain.tail().theirMessageIndex = theirCounter + lc.remoteCommitChain.tail().ourMessageIndex = ourCounter + lc.remoteCommitChain.tail().theirMessageIndex = theirCounter + + return nil +} + type htlcView struct { ourUpdates []*PaymentDescriptor theirUpdates []*PaymentDescriptor @@ -547,6 +644,8 @@ func (lc *LightningChannel) fetchCommitmentView(remoteChain bool, ourMessageIndex: ourLogIndex, theirMessageIndex: theirLogIndex, theirBalance: theirBalance, + outgoingHTLCs: filteredHTLCView.ourUpdates, + incomingHTLCs: filteredHTLCView.theirUpdates, }, nil } @@ -886,26 +985,25 @@ func (lc *LightningChannel) RevokeCurrentCommitment() (*lnwire.CommitRevocation, // Advance our tail, as we've revoked our previous state. lc.localCommitChain.advanceTail() - lc.currentHeight++ + // Additionally, generate a channel delta for this state transition for + // persistent storage. // TODO(roasbeef): update sent/received. tail := lc.localCommitChain.tail() - lc.channelState.OurCommitTx = tail.txn - lc.channelState.OurBalance = tail.ourBalance - lc.channelState.TheirBalance = tail.theirBalance - lc.channelState.OurCommitSig = tail.sig - lc.channelState.NumUpdates++ + delta, err := tail.toChannelDelta() + if err != nil { + return nil, err + } + err = lc.channelState.UpdateCommitment(tail.txn, tail.sig, delta) + if err != nil { + return nil, err + } walletLog.Tracef("ChannelPoint(%v): state transition accepted: "+ "our_balance=%v, their_balance=%v", lc.channelState.ChanID, tail.ourBalance, tail.theirBalance) - // TODO(roasbeef): use RecordChannelDelta once fin - if err := lc.channelState.FullSync(); err != nil { - return nil, err - } - revocationMsg.ChannelPoint = lc.channelState.ChanID return revocationMsg, nil } @@ -974,7 +1072,12 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.CommitRevocation) ( // the current revocation key+hash for the remote party. Therefore we // sync now to ensure the elkrem receiver state is consistent with the // current commitment height. - if err := lc.channelState.SyncRevocation(); err != nil { + tail := lc.remoteCommitChain.tail() + delta, err := tail.toChannelDelta() + if err != nil { + return nil, err + } + if err := lc.channelState.AppendToRevocationLog(delta); err != nil { return nil, err } @@ -1097,6 +1200,8 @@ func (lc *LightningChannel) ExtendRevocationWindow() (*lnwire.CommitRevocation, // AddHTLC adds an HTLC to the state machine's local update log. This method // should be called when preparing to send an outgoing HTLC. +// TODO(roasbeef): check for duplicates below? edge case during restart w/ HTLC +// persistence func (lc *LightningChannel) AddHTLC(htlc *lnwire.HTLCAddRequest) uint32 { pd := &PaymentDescriptor{ EntryType: Add, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 588ecbfd..2df97f49 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -2,6 +2,7 @@ package lnwallet import ( "bytes" + "fmt" "io/ioutil" "os" "testing" @@ -68,14 +69,83 @@ func (m *mockSigner) SignOutputRaw(tx *wire.MsgTx, signDesc *SignDescriptor) ([] return sig[:len(sig)-1], nil } -// ComputeInputScript... func (m *mockSigner) ComputeInputScript(tx *wire.MsgTx, signDesc *SignDescriptor) (*InputScript, error) { return nil, nil } -// createTestChannels creates two test channels funded with 10 BTC, with 5 BTC +// initRevocationWindows simulates a new channel being opened within the p2p +// network by populating the initial revocation windows of the passed +// commitment state machines. +func initRevocationWindows(chanA, chanB *LightningChannel, windowSize int) error { + for i := 0; i < windowSize; i++ { + aliceNextRevoke, err := chanA.ExtendRevocationWindow() + if err != nil { + return err + } + if htlcs, err := chanB.ReceiveRevocation(aliceNextRevoke); err != nil { + return err + } else if htlcs != nil { + return err + } + + bobNextRevoke, err := chanB.ExtendRevocationWindow() + if err != nil { + return err + } + if htlcs, err := chanA.ReceiveRevocation(bobNextRevoke); err != nil { + return err + } else if htlcs != nil { + return err + } + } + + return nil +} + +// forceStateTransition executes the neccessary interaction between the two +// commitment state machines to transition to a new state locking in any +// pending updates. +func forceStateTransition(chanA, chanB *LightningChannel) error { + aliceSig, bobIndex, err := chanA.SignNextCommitment() + if err != nil { + return err + } + if err := chanB.ReceiveNewCommitment(aliceSig, bobIndex); err != nil { + fmt.Println("alice sig invalid") + return err + } + + bobSig, aliceIndex, err := chanB.SignNextCommitment() + if err != nil { + return err + } + bobRevocation, err := chanB.RevokeCurrentCommitment() + if err != nil { + return err + } + + if err := chanA.ReceiveNewCommitment(bobSig, aliceIndex); err != nil { + fmt.Println("bob sig invalid") + return err + } + aliceRevocation, err := chanA.RevokeCurrentCommitment() + if err != nil { + return err + } + + if _, err := chanA.ReceiveRevocation(bobRevocation); err != nil { + return err + } + if _, err := chanB.ReceiveRevocation(aliceRevocation); err != nil { + return err + } + + return nil +} + +// createTestChannels creates two test channels funded witr 10 BTC, with 5 BTC // allocated to each side. -func createTestChannels() (*LightningChannel, *LightningChannel, func(), error) { +func createTestChannels(revocationWindow int) (*LightningChannel, *LightningChannel, func(), error) { aliceKeyPriv, aliceKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey) bobKeyPriv, bobKeyPub := btcec.PrivKeyFromBytes(btcec.S256(), @@ -193,6 +263,13 @@ func createTestChannels() (*LightningChannel, *LightningChannel, func(), error) return nil, nil, nil, err } + // Now that the channel are open, simulate the start of a session by + // having Alice and Bob extend their revocation windows to each other. + err = initRevocationWindows(channelAlice, channelBob, revocationWindow) + if err != nil { + return nil, nil, nil, err + } + return channelAlice, channelBob, cleanUpFunc, nil } @@ -208,39 +285,12 @@ func TestSimpleAddSettleWorkflow(t *testing.T) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. - aliceChannel, bobChannel, cleanUp, err := createTestChannels() + aliceChannel, bobChannel, cleanUp, err := createTestChannels(3) if err != nil { t.Fatalf("unable to create test channels: %v", err) } defer cleanUp() - // Now that the channel are open, simulate the start of a session by - // having Alice and Bob extend their revocation windows to each other. - // For testing purposes we'll use a revocation window of size 3. - for i := 1; i < 4; i++ { - aliceNextRevoke, err := aliceChannel.ExtendRevocationWindow() - if err != nil { - t.Fatalf("unable to create new alice revoke") - } - if htlcs, err := bobChannel.ReceiveRevocation(aliceNextRevoke); err != nil { - t.Fatalf("bob unable to process alice revocation increment: %v", err) - } else if htlcs != nil { - t.Fatalf("revocation window extend should not trigger htlc "+ - "forward, instead %v marked for forwarding", spew.Sdump(htlcs)) - } - - bobNextRevoke, err := bobChannel.ExtendRevocationWindow() - if err != nil { - t.Fatalf("unable to create new bob revoke") - } - if htlcs, err := aliceChannel.ReceiveRevocation(bobNextRevoke); err != nil { - t.Fatalf("bob unable to process alice revocation increment: %v", err) - } else if htlcs != nil { - t.Fatalf("revocation window extend should not trigger htlc "+ - "forward, instead %v marked for forwarding", spew.Sdump(htlcs)) - } - } - // The edge of the revocation window for both sides should be 3 at this // point. if aliceChannel.revocationWindowEdge != 3 { @@ -484,7 +534,7 @@ func TestCooperativeChannelClosure(t *testing.T) { // Create a test channel which will be used for the duration of this // unittest. The channel will be funded evenly with Alice having 5 BTC, // and Bob having 5 BTC. - aliceChannel, bobChannel, cleanUp, err := createTestChannels() + aliceChannel, bobChannel, cleanUp, err := createTestChannels(3) if err != nil { t.Fatalf("unable to create test channels: %v", err) } @@ -526,3 +576,186 @@ func TestCooperativeChannelClosure(t *testing.T) { aliceCloseSha[:], txid[:]) } } + +func TestStateUpdatePersistence(t *testing.T) { + // Create a test channel which will be used for the duration of this + // unittest. The channel will be funded evenly with Alice having 5 BTC, + // and Bob having 5 BTC. + aliceChannel, bobChannel, cleanUp, err := createTestChannels(3) + if err != nil { + t.Fatalf("unable to create test channels: %v", err) + } + defer cleanUp() + + if err := aliceChannel.channelState.FullSync(); err != nil { + t.Fatalf("unable to sync alice's channel: %v", err) + } + if err := bobChannel.channelState.FullSync(); err != nil { + t.Fatalf("unable to sync bob's channel: %v", err) + } + + aliceStartingBalance := aliceChannel.channelState.OurBalance + bobStartingBalance := bobChannel.channelState.OurBalance + + const numHtlcs = 4 + + // Alice adds 3 HTLC's to the update log, while Bob adds a single HTLC. + var alicePreimage [32]byte + copy(alicePreimage[:], bytes.Repeat([]byte{0xaa}, 32)) + var bobPreimage [32]byte + copy(bobPreimage[:], bytes.Repeat([]byte{0xbb}, 32)) + for i := 0; i < 3; i++ { + rHash := fastsha256.Sum256(alicePreimage[:]) + h := &lnwire.HTLCAddRequest{ + RedemptionHashes: [][32]byte{rHash}, + Amount: lnwire.CreditsAmount(1000), + Expiry: uint32(10), + } + + aliceChannel.AddHTLC(h) + bobChannel.ReceiveHTLC(h) + } + rHash := fastsha256.Sum256(bobPreimage[:]) + bobh := &lnwire.HTLCAddRequest{ + RedemptionHashes: [][32]byte{rHash}, + Amount: lnwire.CreditsAmount(1000), + Expiry: uint32(10), + } + bobChannel.AddHTLC(bobh) + aliceChannel.ReceiveHTLC(bobh) + + // Next, Alice initiates a state transition to lock in the above HTLC's. + if err := forceStateTransition(aliceChannel, bobChannel); err != nil { + t.Fatalf("unable to lock in HTLC's: %v", err) + } + + // The balances of both channels should be updated accordingly. + aliceBalance := aliceChannel.channelState.OurBalance + expectedAliceBalance := aliceStartingBalance - btcutil.Amount(3000) + bobBalance := bobChannel.channelState.OurBalance + expectedBobBalance := bobStartingBalance - btcutil.Amount(1000) + if aliceBalance != expectedAliceBalance { + t.Fatalf("expected %v alice balance, got %v", expectedAliceBalance, + aliceBalance) + } + if bobBalance != expectedBobBalance { + t.Fatalf("expected %v bob balance, got %v", expectedBobBalance, + bobBalance) + } + + // The latest commitment from both sides should have all the HTLC's. + numAliceOutgoing := aliceChannel.localCommitChain.tail().outgoingHTLCs + numAliceIncoming := aliceChannel.localCommitChain.tail().incomingHTLCs + if len(numAliceOutgoing) != 3 { + t.Fatalf("expected %v htlcs, instead got %v", 3, numAliceOutgoing) + } + if len(numAliceIncoming) != 1 { + t.Fatalf("expected %v htlcs, instead got %v", 1, numAliceIncoming) + } + numBobOutgoing := bobChannel.localCommitChain.tail().outgoingHTLCs + numBobIncoming := bobChannel.localCommitChain.tail().incomingHTLCs + if len(numBobOutgoing) != 1 { + t.Fatalf("expected %v htlcs, instead got %v", 1, numBobOutgoing) + } + if len(numBobIncoming) != 3 { + t.Fatalf("expected %v htlcs, instead got %v", 3, numBobIncoming) + } + + // Now fetch both of the channels created above from disk to simulate a + // node restart with persistence. + id := wire.ShaHash(testHdSeed) + aliceChannels, err := aliceChannel.channelState.Db.FetchOpenChannels(&id) + if err != nil { + t.Fatalf("unable to fetch channel: %v", err) + } + bobChannels, err := bobChannel.channelState.Db.FetchOpenChannels(&id) + if err != nil { + t.Fatalf("unable to fetch channel: %v", err) + } + aliceChannelNew, err := NewLightningChannel(aliceChannel.signer, nil, nil, aliceChannels[0]) + if err != nil { + t.Fatalf("unable to create new channel: %v", err) + } + bobChannelNew, err := NewLightningChannel(bobChannel.signer, nil, nil, bobChannels[0]) + if err != nil { + t.Fatalf("unable to create new channel: %v", err) + } + if err := initRevocationWindows(aliceChannelNew, bobChannelNew, 3); err != nil { + t.Fatalf("unable to init revocation windows: %v", err) + } + + // The state update logs of the new channels and the old channels + // should now be identical other than the height the HTLC's were added. + if aliceChannel.ourLogCounter != aliceChannelNew.ourLogCounter { + t.Fatalf("alice log counter: expected %v, got %v", + aliceChannel.ourLogCounter, aliceChannelNew.ourLogCounter) + } + if aliceChannel.theirLogCounter != aliceChannelNew.theirLogCounter { + t.Fatalf("alice log counter: expected %v, got %v", + aliceChannel.theirLogCounter, aliceChannelNew.theirLogCounter) + } + if aliceChannel.ourUpdateLog.Len() != aliceChannelNew.ourUpdateLog.Len() { + t.Fatalf("alice log len: expected %v, got %v", + aliceChannel.ourUpdateLog.Len(), + aliceChannelNew.ourUpdateLog.Len()) + } + if aliceChannel.theirUpdateLog.Len() != aliceChannelNew.theirUpdateLog.Len() { + t.Fatalf("alice log len: expected %v, got %v", + aliceChannel.theirUpdateLog.Len(), + aliceChannelNew.theirUpdateLog.Len()) + } + if bobChannel.ourLogCounter != bobChannelNew.ourLogCounter { + t.Fatalf("bob log counter: expected %v, got %v", + bobChannel.ourLogCounter, bobChannelNew.ourLogCounter) + } + if bobChannel.theirLogCounter != bobChannelNew.theirLogCounter { + t.Fatalf("bob log counter: expected %v, got %v", + bobChannel.theirLogCounter, bobChannelNew.theirLogCounter) + } + if bobChannel.ourUpdateLog.Len() != bobChannelNew.ourUpdateLog.Len() { + t.Fatalf("bob log len: expected %v, got %v", + bobChannelNew.ourUpdateLog.Len(), bobChannelNew.ourUpdateLog.Len()) + } + if bobChannel.theirUpdateLog.Len() != bobChannelNew.theirUpdateLog.Len() { + t.Fatalf("bob log len: expected %v, got %v", + bobChannel.theirUpdateLog.Len(), bobChannelNew.theirUpdateLog.Len()) + } + + // Now settle all the HTLC's, then force a state update. The state + // update should suceed as both sides have identical. + for i := 0; i < 3; i++ { + settleIndex, err := bobChannelNew.SettleHTLC(alicePreimage) + if err != nil { + t.Fatalf("unable to settle htlc: %v", err) + } + err = aliceChannelNew.ReceiveHTLCSettle(alicePreimage, settleIndex) + if err != nil { + t.Fatalf("unable to settle htlc: %v", err) + } + } + settleIndex, err := aliceChannelNew.SettleHTLC(bobPreimage) + if err != nil { + t.Fatalf("unable to settle htlc: %v", err) + } + err = bobChannelNew.ReceiveHTLCSettle(bobPreimage, settleIndex) + if err != nil { + t.Fatalf("unable to settle htlc: %v", err) + } + if err := forceStateTransition(aliceChannelNew, bobChannelNew); err != nil { + t.Fatalf("unable to update commitments: %v", err) + } + + // The balances of both sides should have been updated accordingly. + aliceBalance = aliceChannelNew.channelState.OurBalance + expectedAliceBalance = aliceStartingBalance - btcutil.Amount(2000) + bobBalance = bobChannelNew.channelState.OurBalance + expectedBobBalance = bobStartingBalance + btcutil.Amount(2000) + if aliceBalance != expectedAliceBalance { + t.Fatalf("expected %v alice balance, got %v", expectedAliceBalance, + aliceBalance) + } + if bobBalance != expectedBobBalance { + t.Fatalf("expected %v bob balance, got %v", expectedBobBalance, + bobBalance) + } +}