diff --git a/channeldb/channel.go b/channeldb/channel.go index 70722dfa..696cc4ed 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -43,6 +43,16 @@ var ( // funding flow. chanInfoKey = []byte("chan-info-key") + // localUpfrontShutdownKey can be accessed within the bucket for a channel + // (identified by its chanPoint). This key stores an optional upfront + // shutdown script for the local peer. + localUpfrontShutdownKey = []byte("local-upfront-shutdown-key") + + // remoteUpfrontShutdownKey can be accessed within the bucket for a channel + // (identified by its chanPoint). This key stores an optional upfront + // shutdown script for the remote peer. + remoteUpfrontShutdownKey = []byte("remote-upfront-shutdown-key") + // chanCommitmentKey can be accessed within the sub-bucket for a // particular channel. This key stores the up to date commitment state // for a particular channel party. Appending a 0 to the end of this key @@ -551,6 +561,16 @@ type OpenChannel struct { // method on the ChanType field. FundingTxn *wire.MsgTx + // LocalShutdownScript is set to a pre-set script if the channel was opened + // by the local node with option_upfront_shutdown_script set. If the option + // was not set, the field is empty. + LocalShutdownScript lnwire.DeliveryAddress + + // RemoteShutdownScript is set to a pre-set script if the channel was opened + // by the remote node with option_upfront_shutdown_script set. If the option + // was not set, the field is empty. + RemoteShutdownScript lnwire.DeliveryAddress + // TODO(roasbeef): eww Db *DB @@ -2573,7 +2593,60 @@ func putChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { return err } - return chanBucket.Put(chanInfoKey, w.Bytes()) + if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil { + return err + } + + // Finally, add optional shutdown scripts for the local and remote peer if + // they are present. + if err := putOptionalUpfrontShutdownScript( + chanBucket, localUpfrontShutdownKey, channel.LocalShutdownScript, + ); err != nil { + return err + } + + return putOptionalUpfrontShutdownScript( + chanBucket, remoteUpfrontShutdownKey, channel.RemoteShutdownScript, + ) +} + +// putOptionalUpfrontShutdownScript adds a shutdown script under the key +// provided if it has a non-zero length. +func putOptionalUpfrontShutdownScript(chanBucket *bbolt.Bucket, key []byte, + script []byte) error { + // If the script is empty, we do not need to add anything. + if len(script) == 0 { + return nil + } + + var w bytes.Buffer + if err := WriteElement(&w, script); err != nil { + return err + } + + return chanBucket.Put(key, w.Bytes()) +} + +// getOptionalUpfrontShutdownScript reads the shutdown script stored under the +// key provided if it is present. Upfront shutdown scripts are optional, so the +// function returns with no error if the key is not present. +func getOptionalUpfrontShutdownScript(chanBucket *bbolt.Bucket, key []byte, + script *lnwire.DeliveryAddress) error { + + // Return early if the bucket does not exit, a shutdown script was not set. + bs := chanBucket.Get(key) + if bs == nil { + return nil + } + + var tempScript []byte + r := bytes.NewReader(bs) + if err := ReadElement(r, &tempScript); err != nil { + return err + } + *script = tempScript + + return nil } func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { @@ -2696,7 +2769,16 @@ func fetchChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { channel.Packager = NewChannelPackager(channel.ShortChannelID) - return nil + // Finally, read the optional shutdown scripts. + if err := getOptionalUpfrontShutdownScript( + chanBucket, localUpfrontShutdownKey, &channel.LocalShutdownScript, + ); err != nil { + return err + } + + return getOptionalUpfrontShutdownScript( + chanBucket, remoteUpfrontShutdownKey, &channel.RemoteShutdownScript, + ) } func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 645b3030..931f2879 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -345,6 +345,101 @@ func TestOpenChannelPutGetDelete(t *testing.T) { } } +// TestOptionalShutdown tests the reading and writing of channels with and +// without optional shutdown script fields. +func TestOptionalShutdown(t *testing.T) { + local := lnwire.DeliveryAddress([]byte("local shutdown script")) + remote := lnwire.DeliveryAddress([]byte("remote shutdown script")) + + if _, err := rand.Read(remote); err != nil { + t.Fatalf("Could not create random script: %v", err) + } + + tests := []struct { + name string + modifyChannel func(channel *OpenChannel) + expectedLocal lnwire.DeliveryAddress + expectedRemote lnwire.DeliveryAddress + }{ + { + name: "no shutdown scripts", + modifyChannel: func(channel *OpenChannel) {}, + }, + { + name: "local shutdown script", + modifyChannel: func(channel *OpenChannel) { + channel.LocalShutdownScript = local + }, + expectedLocal: local, + }, + { + name: "remote shutdown script", + modifyChannel: func(channel *OpenChannel) { + channel.RemoteShutdownScript = remote + }, + expectedRemote: remote, + }, + { + name: "both scripts set", + modifyChannel: func(channel *OpenChannel) { + channel.LocalShutdownScript = local + channel.RemoteShutdownScript = remote + }, + expectedLocal: local, + expectedRemote: remote, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state, then add an additional fake HTLC + // before syncing to disk. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + test.modifyChannel(state) + + // Write channels to Db. + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := state.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channel: %v", err) + } + + if len(openChannels) != 1 { + t.Fatalf("Expected one channel open, got: %v", len(openChannels)) + } + + if !bytes.Equal(openChannels[0].LocalShutdownScript, test.expectedLocal) { + t.Fatalf("Expected local: %x, got: %x", test.expectedLocal, + openChannels[0].LocalShutdownScript) + } + + if !bytes.Equal(openChannels[0].RemoteShutdownScript, test.expectedRemote) { + t.Fatalf("Expected remote: %x, got: %x", test.expectedRemote, + openChannels[0].RemoteShutdownScript) + } + }) + } +} + func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { if !reflect.DeepEqual(a, b) { _, _, line, _ := runtime.Caller(1) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index a958dea8..f040aecd 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -4921,13 +4921,13 @@ func (lc *LightningChannel) ShortChanID() lnwire.ShortChannelID { // LocalUpfrontShutdownScript returns the local upfront shutdown script for the // channel. If it was not set, an empty byte array is returned. func (lc *LightningChannel) LocalUpfrontShutdownScript() lnwire.DeliveryAddress { - return nil + return lc.channelState.LocalShutdownScript } // RemoteUpfrontShutdownScript returns the remote upfront shutdown script for the // channel. If it was not set, an empty byte array is returned. func (lc *LightningChannel) RemoteUpfrontShutdownScript() lnwire.DeliveryAddress { - return nil + return lc.channelState.RemoteShutdownScript } // genHtlcScript generates the proper P2WSH public key scripts for the HTLC