Merge pull request #2370 from Roasbeef/static-chan-backups-chanbackup
chanbackup: add new package implementing static channel backups
This commit is contained in:
commit
e9889cb899
99
chanbackup/backup.go
Normal file
99
chanbackup/backup.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LiveChannelSource is an interface that allows us to query for the set of
|
||||||
|
// live channels. A live channel is one that is open, and has not had a
|
||||||
|
// commitment transaction broadcast.
|
||||||
|
type LiveChannelSource interface {
|
||||||
|
// FetchAllChannels returns all known live channels.
|
||||||
|
FetchAllChannels() ([]*channeldb.OpenChannel, error)
|
||||||
|
|
||||||
|
// FetchChannel attempts to locate a live channel identified by the
|
||||||
|
// passed chanPoint.
|
||||||
|
FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error)
|
||||||
|
|
||||||
|
// AddrsForNode returns all known addresses for the target node public
|
||||||
|
// key.
|
||||||
|
AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// assembleChanBackup attempts to assemble a static channel backup for the
|
||||||
|
// passed open channel. The backup includes all information required to restore
|
||||||
|
// the channel, as well as addressing information so we can find the peer and
|
||||||
|
// reconnect to them to initiate the protocol.
|
||||||
|
func assembleChanBackup(chanSource LiveChannelSource,
|
||||||
|
openChan *channeldb.OpenChannel) (*Single, error) {
|
||||||
|
|
||||||
|
log.Debugf("Crafting backup for ChannelPoint(%v)",
|
||||||
|
openChan.FundingOutpoint)
|
||||||
|
|
||||||
|
// First, we'll query the channel source to obtain all the addresses
|
||||||
|
// that are are associated with the peer for this channel.
|
||||||
|
nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(openChan, nodeAddrs)
|
||||||
|
|
||||||
|
return &single, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchBackupForChan attempts to create a plaintext static channel backup for
|
||||||
|
// the target channel identified by its channel point. If we're unable to find
|
||||||
|
// the target channel, then an error will be returned.
|
||||||
|
func FetchBackupForChan(chanPoint wire.OutPoint,
|
||||||
|
chanSource LiveChannelSource) (*Single, error) {
|
||||||
|
|
||||||
|
// First, we'll query the channel source to see if the channel is known
|
||||||
|
// and open within the database.
|
||||||
|
targetChan, err := chanSource.FetchChannel(chanPoint)
|
||||||
|
if err != nil {
|
||||||
|
// If we can't find the channel, then we return with an error,
|
||||||
|
// as we have nothing to backup.
|
||||||
|
return nil, fmt.Errorf("unable to find target channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Once we have the target channel, we can assemble the backup using
|
||||||
|
// the source to obtain any extra information that we may need.
|
||||||
|
staticChanBackup, err := assembleChanBackup(chanSource, targetChan)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to create chan backup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return staticChanBackup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchStaticChanBackups will return a plaintext static channel back up for
|
||||||
|
// all known active/open channels within the passed channel source.
|
||||||
|
func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) {
|
||||||
|
// First, we'll query the backup source for information concerning all
|
||||||
|
// currently open and available channels.
|
||||||
|
openChans, err := chanSource.FetchAllChannels()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we have all the channels, we'll use the chanSource to
|
||||||
|
// obtain any auxiliary information we need to craft a backup for each
|
||||||
|
// channel.
|
||||||
|
staticChanBackups := make([]Single, len(openChans))
|
||||||
|
for i, openChan := range openChans {
|
||||||
|
chanBackup, err := assembleChanBackup(chanSource, openChan)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
staticChanBackups[i] = *chanBackup
|
||||||
|
}
|
||||||
|
|
||||||
|
return staticChanBackups, nil
|
||||||
|
}
|
197
chanbackup/backup_test.go
Normal file
197
chanbackup/backup_test.go
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockChannelSource struct {
|
||||||
|
chans map[wire.OutPoint]*channeldb.OpenChannel
|
||||||
|
|
||||||
|
failQuery bool
|
||||||
|
|
||||||
|
addrs map[[33]byte][]net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockChannelSource() *mockChannelSource {
|
||||||
|
return &mockChannelSource{
|
||||||
|
chans: make(map[wire.OutPoint]*channeldb.OpenChannel),
|
||||||
|
addrs: make(map[[33]byte][]net.Addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error) {
|
||||||
|
if m.failQuery {
|
||||||
|
return nil, fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
chans := make([]*channeldb.OpenChannel, 0, len(m.chans))
|
||||||
|
for _, channel := range m.chans {
|
||||||
|
chans = append(chans, channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
return chans, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChannelSource) FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) {
|
||||||
|
if m.failQuery {
|
||||||
|
return nil, fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
channel, ok := m.chans[chanPoint]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("can't find chan")
|
||||||
|
}
|
||||||
|
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []net.Addr) {
|
||||||
|
var nodeKey [33]byte
|
||||||
|
copy(nodeKey[:], nodePub.SerializeCompressed())
|
||||||
|
|
||||||
|
m.addrs[nodeKey] = addrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
|
||||||
|
if m.failQuery {
|
||||||
|
return nil, fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
var nodeKey [33]byte
|
||||||
|
copy(nodeKey[:], nodePub.SerializeCompressed())
|
||||||
|
|
||||||
|
addrs, ok := m.addrs[nodeKey]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("can't find addr")
|
||||||
|
}
|
||||||
|
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchBackupForChan tests that we're able to construct a single channel
|
||||||
|
// backup for channels that are known, unknown, and also channels in which we
|
||||||
|
// can find addresses for and otherwise.
|
||||||
|
func TestFetchBackupForChan(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// First, we'll make two channels, only one of them will have all the
|
||||||
|
// information we need to construct set of backups for them.
|
||||||
|
randomChan1, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate chan: %v", err)
|
||||||
|
}
|
||||||
|
randomChan2, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate chan: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chanSource := newMockChannelSource()
|
||||||
|
chanSource.chans[randomChan1.FundingOutpoint] = randomChan1
|
||||||
|
chanSource.chans[randomChan2.FundingOutpoint] = randomChan2
|
||||||
|
|
||||||
|
chanSource.addAddrsForNode(randomChan1.IdentityPub, []net.Addr{addr1})
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
chanPoint wire.OutPoint
|
||||||
|
|
||||||
|
pass bool
|
||||||
|
}{
|
||||||
|
// Able to find channel, and addresses, should pass.
|
||||||
|
{
|
||||||
|
chanPoint: randomChan1.FundingOutpoint,
|
||||||
|
pass: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Able to find channel, not able to find addrs, should fail.
|
||||||
|
{
|
||||||
|
chanPoint: randomChan2.FundingOutpoint,
|
||||||
|
pass: false,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Not able to find channel, should fail.
|
||||||
|
{
|
||||||
|
chanPoint: op,
|
||||||
|
pass: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
_, err := FetchBackupForChan(testCase.chanPoint, chanSource)
|
||||||
|
switch {
|
||||||
|
// If this is a valid test case, and we failed, then we'll
|
||||||
|
// return an error.
|
||||||
|
case err != nil && testCase.pass:
|
||||||
|
t.Fatalf("#%v, unable to make chan backup: %v", i, err)
|
||||||
|
|
||||||
|
// If this is an invalid test case, and we passed it, then
|
||||||
|
// we'll return an error.
|
||||||
|
case err == nil && !testCase.pass:
|
||||||
|
t.Fatalf("#%v got nil error for invalid req: %v",
|
||||||
|
i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFetchStaticChanBackups tests that we're able to properly query the
|
||||||
|
// channel source for all channels and construct a Single for each channel.
|
||||||
|
func TestFetchStaticChanBackups(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// First, we'll make the set of channels that we want to seed the
|
||||||
|
// channel source with. Both channels will be fully populated in the
|
||||||
|
// channel source.
|
||||||
|
const numChans = 2
|
||||||
|
randomChan1, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate chan: %v", err)
|
||||||
|
}
|
||||||
|
randomChan2, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to generate chan: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chanSource := newMockChannelSource()
|
||||||
|
chanSource.chans[randomChan1.FundingOutpoint] = randomChan1
|
||||||
|
chanSource.chans[randomChan2.FundingOutpoint] = randomChan2
|
||||||
|
chanSource.addAddrsForNode(randomChan1.IdentityPub, []net.Addr{addr1})
|
||||||
|
chanSource.addAddrsForNode(randomChan2.IdentityPub, []net.Addr{addr2})
|
||||||
|
|
||||||
|
// With the channel source populated, we'll now attempt to create a set
|
||||||
|
// of backups for all the channels. This should succeed, as all items
|
||||||
|
// are populated within the channel source.
|
||||||
|
backups, err := FetchStaticChanBackups(chanSource)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create chan back ups: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(backups) != numChans {
|
||||||
|
t.Fatalf("expected %v chans, instead got %v", numChans,
|
||||||
|
len(backups))
|
||||||
|
}
|
||||||
|
|
||||||
|
// We'll attempt to create a set up backups again, but this time the
|
||||||
|
// second channel will have missing information, which should cause the
|
||||||
|
// query to fail.
|
||||||
|
var n [33]byte
|
||||||
|
copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
|
||||||
|
delete(chanSource.addrs, n)
|
||||||
|
|
||||||
|
_, err = FetchStaticChanBackups(chanSource)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("query with incomplete information should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
// To wrap up, we'll ensure that if we're unable to query the channel
|
||||||
|
// source at all, then we'll fail as well.
|
||||||
|
chanSource = newMockChannelSource()
|
||||||
|
chanSource.failQuery = true
|
||||||
|
_, err = FetchStaticChanBackups(chanSource)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("query should fail")
|
||||||
|
}
|
||||||
|
}
|
160
chanbackup/backupfile.go
Normal file
160
chanbackup/backupfile.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultBackupFileName is the default name of the auto updated static
|
||||||
|
// channel backup fie.
|
||||||
|
DefaultBackupFileName = "channel.backup"
|
||||||
|
|
||||||
|
// DefaultTempBackupFileName is the default name of the temporary SCB
|
||||||
|
// file that we'll use to atomically update the primary back up file
|
||||||
|
// when new channel are detected.
|
||||||
|
DefaultTempBackupFileName = "temp-dont-use.backup"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNoBackupFileExists is returned if caller attempts to call
|
||||||
|
// UpdateAndSwap with the file name not set.
|
||||||
|
ErrNoBackupFileExists = fmt.Errorf("back up file name not set")
|
||||||
|
|
||||||
|
// ErrNoTempBackupFile is returned if caller attempts to call
|
||||||
|
// UpdateAndSwap with the temp back up file name not set.
|
||||||
|
ErrNoTempBackupFile = fmt.Errorf("temp backup file not set")
|
||||||
|
)
|
||||||
|
|
||||||
|
// MultiFile represents a file on disk that a caller can use to read the packed
|
||||||
|
// multi backup into an unpacked one, and also atomically update the contents
|
||||||
|
// on disk once new channels have been opened, and old ones closed. This struct
|
||||||
|
// relies on an atomic file rename property which most widely use file systems
|
||||||
|
// have.
|
||||||
|
type MultiFile struct {
|
||||||
|
// fileName is the file name of the main back up file.
|
||||||
|
fileName string
|
||||||
|
|
||||||
|
// mainFile is an open handle to the main back up file.
|
||||||
|
mainFile *os.File
|
||||||
|
|
||||||
|
// tempFileName is the name of the file that we'll use to stage a new
|
||||||
|
// packed multi-chan backup, and the rename to the main back up file.
|
||||||
|
tempFileName string
|
||||||
|
|
||||||
|
// tempFile is an open handle to the temp back up file.
|
||||||
|
tempFile *os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMultiFile create a new multi-file instance at the target location on the
|
||||||
|
// file system.
|
||||||
|
func NewMultiFile(fileName string) *MultiFile {
|
||||||
|
|
||||||
|
// We'll our temporary backup file in the very same directory as the
|
||||||
|
// main backup file.
|
||||||
|
backupFileDir := filepath.Dir(fileName)
|
||||||
|
tempFileName := filepath.Join(
|
||||||
|
backupFileDir, DefaultTempBackupFileName,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &MultiFile{
|
||||||
|
fileName: fileName,
|
||||||
|
tempFileName: tempFileName,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAndSwap will attempt write a new temporary backup file to disk with
|
||||||
|
// the newBackup encoded, then atomically swap (via rename) the old file for
|
||||||
|
// the new file by updating the name of the new file to the old.
|
||||||
|
func (b *MultiFile) UpdateAndSwap(newBackup PackedMulti) error {
|
||||||
|
// If the main backup file isn't set, then we can't proceed.
|
||||||
|
if b.fileName == "" {
|
||||||
|
return ErrNoBackupFileExists
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the old back up file still exists, then we'll delete it before
|
||||||
|
// proceeding.
|
||||||
|
if _, err := os.Stat(b.tempFileName); err == nil {
|
||||||
|
log.Infof("Found old temp backup @ %v, removing before swap",
|
||||||
|
b.tempFileName)
|
||||||
|
|
||||||
|
err = os.Remove(b.tempFileName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to remove temp "+
|
||||||
|
"backup file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we know the staging area is clear, we'll create the new
|
||||||
|
// temporary back up file.
|
||||||
|
var err error
|
||||||
|
b.tempFile, err = os.Create(b.tempFileName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the file created, we'll write the new packed multi backup and
|
||||||
|
// remove the temporary file all together once this method exits.
|
||||||
|
_, err = b.tempFile.Write([]byte(newBackup))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := b.tempFile.Sync(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer os.Remove(b.tempFileName)
|
||||||
|
|
||||||
|
log.Infof("Swapping old multi backup file from %v to %v",
|
||||||
|
b.tempFileName, b.fileName)
|
||||||
|
|
||||||
|
// Finally, we'll attempt to atomically rename the temporary file to
|
||||||
|
// the main back up file. If this succeeds, then we'll only have a
|
||||||
|
// single file on disk once this method exits.
|
||||||
|
return os.Rename(b.tempFileName, b.fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractMulti attempts to extract the packed multi backup we currently point
|
||||||
|
// to into an unpacked version. This method will fail if no backup file
|
||||||
|
// currently exists as the specified location.
|
||||||
|
func (b *MultiFile) ExtractMulti(keyChain keychain.KeyRing) (*Multi, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// If the backup file isn't already set, then we'll attempt to open it
|
||||||
|
// anew.
|
||||||
|
if b.mainFile == nil {
|
||||||
|
// We'll return an error if the main file isn't currently set.
|
||||||
|
if b.fileName == "" {
|
||||||
|
return nil, ErrNoBackupFileExists
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, we'll open the file to prep for reading the
|
||||||
|
// contents.
|
||||||
|
b.mainFile, err = os.Open(b.fileName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before we start to read the file, we'll ensure that the next read
|
||||||
|
// call will start from the front of the file.
|
||||||
|
_, err = b.mainFile.Seek(0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// With our seek successful, we'll now attempt to read the contents of
|
||||||
|
// the entire file in one swoop.
|
||||||
|
multiBytes, err := ioutil.ReadAll(b.mainFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, we'll attempt to unpack the file and return the unpack
|
||||||
|
// version to the caller.
|
||||||
|
packedMulti := PackedMulti(multiBytes)
|
||||||
|
return packedMulti.Unpack(keyChain)
|
||||||
|
}
|
289
chanbackup/backupfile_test.go
Normal file
289
chanbackup/backupfile_test.go
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"math/rand"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeFakePackedMulti() (PackedMulti, error) {
|
||||||
|
newPackedMulti := make([]byte, 50)
|
||||||
|
if _, err := rand.Read(newPackedMulti[:]); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to make test backup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return PackedMulti(newPackedMulti), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertBackupMatches(t *testing.T, filePath string,
|
||||||
|
currentBackup PackedMulti) {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
packedBackup, err := ioutil.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(packedBackup, currentBackup) {
|
||||||
|
t.Fatalf("backups don't match after first swap: "+
|
||||||
|
"expected %x got %x", packedBackup[:],
|
||||||
|
currentBackup)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertFileDeleted(t *testing.T, filePath string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, err := os.Stat(filePath)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("file %v still exists: ", filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUpdateAndSwap test that we're able to properly swap out old backups on
|
||||||
|
// disk with new ones. Additionally, after a swap operation succeeds, then each
|
||||||
|
// time we should only have the main backup file on disk, as the temporary file
|
||||||
|
// has been removed.
|
||||||
|
func TestUpdateAndSwap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tempTestDir, err := ioutil.TempDir("", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make temp dir: %v", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempTestDir)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
fileName string
|
||||||
|
tempFileName string
|
||||||
|
|
||||||
|
oldTempExists bool
|
||||||
|
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
// Main file name is blank, should fail.
|
||||||
|
{
|
||||||
|
fileName: "",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Old temporary file still exists, should be removed. Only one
|
||||||
|
// file should remain.
|
||||||
|
{
|
||||||
|
fileName: filepath.Join(
|
||||||
|
tempTestDir, DefaultBackupFileName,
|
||||||
|
),
|
||||||
|
tempFileName: filepath.Join(
|
||||||
|
tempTestDir, DefaultTempBackupFileName,
|
||||||
|
),
|
||||||
|
oldTempExists: true,
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Old temp doesn't exist, should swap out file, only a single
|
||||||
|
// file remains.
|
||||||
|
{
|
||||||
|
fileName: filepath.Join(
|
||||||
|
tempTestDir, DefaultBackupFileName,
|
||||||
|
),
|
||||||
|
tempFileName: filepath.Join(
|
||||||
|
tempTestDir, DefaultTempBackupFileName,
|
||||||
|
),
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
// Ensure that all created files are removed at the end of the
|
||||||
|
// test case.
|
||||||
|
defer os.Remove(testCase.fileName)
|
||||||
|
defer os.Remove(testCase.tempFileName)
|
||||||
|
|
||||||
|
backupFile := NewMultiFile(testCase.fileName)
|
||||||
|
|
||||||
|
// To start with, we'll make a random byte slice that'll pose
|
||||||
|
// as our packed multi backup.
|
||||||
|
newPackedMulti, err := makeFakePackedMulti()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make test backup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the old temporary file is meant to exist, then we'll
|
||||||
|
// create it now as an empty file.
|
||||||
|
if testCase.oldTempExists {
|
||||||
|
_, err := os.Create(testCase.tempFileName)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): mock out fs calls?
|
||||||
|
}
|
||||||
|
|
||||||
|
// With our backup created, we'll now attempt to swap out this
|
||||||
|
// backup, for the old one.
|
||||||
|
err = backupFile.UpdateAndSwap(PackedMulti(newPackedMulti))
|
||||||
|
switch {
|
||||||
|
// If this is a valid test case, and we failed, then we'll
|
||||||
|
// return an error.
|
||||||
|
case err != nil && testCase.valid:
|
||||||
|
t.Fatalf("#%v, unable to swap file: %v", i, err)
|
||||||
|
|
||||||
|
// If this is an invalid test case, and we passed it, then
|
||||||
|
// we'll return an error.
|
||||||
|
case err == nil && !testCase.valid:
|
||||||
|
t.Fatalf("#%v file swap should have failed: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !testCase.valid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we read out the file on disk, then it should match
|
||||||
|
// exactly what we wrote. The temp backup file should also be
|
||||||
|
// gone.
|
||||||
|
assertBackupMatches(t, testCase.fileName, newPackedMulti)
|
||||||
|
assertFileDeleted(t, testCase.tempFileName)
|
||||||
|
|
||||||
|
// Now that we know this is a valid test case, we'll make a new
|
||||||
|
// packed multi to swap out this current one.
|
||||||
|
newPackedMulti2, err := makeFakePackedMulti()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make test backup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We'll then attempt to swap the old version for this new one.
|
||||||
|
err = backupFile.UpdateAndSwap(PackedMulti(newPackedMulti2))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to swap file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Once again, the file written on disk should have been
|
||||||
|
// properly swapped out with the new instance.
|
||||||
|
assertBackupMatches(t, testCase.fileName, newPackedMulti2)
|
||||||
|
|
||||||
|
// Additionally, we shouldn't be able to find the temp backup
|
||||||
|
// file on disk, as it should be deleted each time.
|
||||||
|
assertFileDeleted(t, testCase.tempFileName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertMultiEqual(t *testing.T, a, b *Multi) {
|
||||||
|
|
||||||
|
if len(a.StaticBackups) != len(b.StaticBackups) {
|
||||||
|
t.Fatalf("expected %v backups, got %v", len(a.StaticBackups),
|
||||||
|
len(b.StaticBackups))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(a.StaticBackups); i++ {
|
||||||
|
assertSingleEqual(t, a.StaticBackups[i], b.StaticBackups[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractMulti tests that given a valid packed multi file on disk, we're
|
||||||
|
// able to read it multiple times repeatedly.
|
||||||
|
func TestExtractMulti(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
// First, as prep, we'll create a single chan backup, then pack that
|
||||||
|
// fully into a multi backup.
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to gen chan: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
singleBackup := NewSingle(channel, nil)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
unpackedMulti := Multi{
|
||||||
|
StaticBackups: []Single{singleBackup},
|
||||||
|
}
|
||||||
|
err = unpackedMulti.PackToWriter(&b, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to pack to writer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packedMulti := PackedMulti(b.Bytes())
|
||||||
|
|
||||||
|
// Finally, we'll make a new temporary file, then write out the packed
|
||||||
|
// multi directly to to it.
|
||||||
|
tempFile, err := ioutil.TempFile("", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create temp file: %v", err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempFile.Name())
|
||||||
|
|
||||||
|
_, err = tempFile.Write(packedMulti)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to write temp file: %v", err)
|
||||||
|
}
|
||||||
|
if err := tempFile.Sync(); err != nil {
|
||||||
|
t.Fatalf("unable to sync temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
fileName string
|
||||||
|
pass bool
|
||||||
|
}{
|
||||||
|
// Main file not read, file name not present.
|
||||||
|
{
|
||||||
|
fileName: "",
|
||||||
|
pass: false,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Main file not read, file name is there, but file doesn't
|
||||||
|
// exist.
|
||||||
|
{
|
||||||
|
fileName: "kek",
|
||||||
|
pass: false,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Main file not read, should be able to read multiple times.
|
||||||
|
{
|
||||||
|
fileName: tempFile.Name(),
|
||||||
|
pass: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, testCase := range testCases {
|
||||||
|
// First, we'll make our backup file with the specified name.
|
||||||
|
backupFile := NewMultiFile(testCase.fileName)
|
||||||
|
|
||||||
|
// With our file made, we'll now attempt to read out the
|
||||||
|
// multi-file.
|
||||||
|
freshUnpackedMulti, err := backupFile.ExtractMulti(keyRing)
|
||||||
|
switch {
|
||||||
|
// If this is a valid test case, and we failed, then we'll
|
||||||
|
// return an error.
|
||||||
|
case err != nil && testCase.pass:
|
||||||
|
t.Fatalf("#%v, unable to extract file: %v", i, err)
|
||||||
|
|
||||||
|
// If this is an invalid test case, and we passed it, then
|
||||||
|
// we'll return an error.
|
||||||
|
case err == nil && !testCase.pass:
|
||||||
|
t.Fatalf("#%v file extraction should have "+
|
||||||
|
"failed: %v", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !testCase.pass {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// We'll now ensure that the unpacked multi we read is
|
||||||
|
// identical to the one we wrote out above.
|
||||||
|
assertMultiEqual(t, &unpackedMulti, freshUnpackedMulti)
|
||||||
|
|
||||||
|
// We should also be able to read the file again, as we have an
|
||||||
|
// existing handle to it.
|
||||||
|
freshUnpackedMulti, err = backupFile.ExtractMulti(keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to unpack multi: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertMultiEqual(t, &unpackedMulti, freshUnpackedMulti)
|
||||||
|
}
|
||||||
|
}
|
140
chanbackup/crypto.go
Normal file
140
chanbackup/crypto.go
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(roasbeef): interface in front of?
|
||||||
|
|
||||||
|
// baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base
|
||||||
|
// encryption key used for encrypting all static channel backups. We use this
|
||||||
|
// to then derive the actual key that we'll use for encryption. We do this
|
||||||
|
// rather than using the raw key, as we assume that we can't obtain the raw
|
||||||
|
// keys, and we don't want to require that the HSM know our target cipher for
|
||||||
|
// encryption.
|
||||||
|
//
|
||||||
|
// TODO(roasbeef): possibly unique encrypt?
|
||||||
|
var baseEncryptionKeyLoc = keychain.KeyLocator{
|
||||||
|
Family: keychain.KeyFamilyStaticBackup,
|
||||||
|
Index: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// genEncryptionKey derives the key that we'll use to encrypt all of our static
|
||||||
|
// channel backups. The key itself, is the sha2 of a base key that we get from
|
||||||
|
// the keyring. We derive the key this way as we don't force the HSM (or any
|
||||||
|
// future abstractions) to be able to derive and know of the cipher that we'll
|
||||||
|
// use within our protocol.
|
||||||
|
func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) {
|
||||||
|
// key = SHA256(baseKey)
|
||||||
|
baseKey, err := keyRing.DeriveKey(
|
||||||
|
baseEncryptionKeyLoc,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptionKey := sha256.Sum256(
|
||||||
|
baseKey.PubKey.SerializeCompressed(),
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO(roasbeef): throw back in ECDH?
|
||||||
|
|
||||||
|
return encryptionKey[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encryptPayloadToWriter attempts to write the set of bytes contained within
|
||||||
|
// the passed byes.Buffer into the passed io.Writer in an encrypted form. We
|
||||||
|
// use a 24-byte chachapoly AEAD instance with a randomized nonce that's
|
||||||
|
// pre-pended to the final payload and used as associated data in the AEAD. We
|
||||||
|
// use the passed keyRing to generate the encryption key, see genEncryptionKey
|
||||||
|
// for further details.
|
||||||
|
func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
|
||||||
|
keyRing keychain.KeyRing) error {
|
||||||
|
|
||||||
|
// First, we'll derive the key that we'll use to encrypt the payload
|
||||||
|
// for safe storage without giving away the details of any of our
|
||||||
|
// channels. The final operation is:
|
||||||
|
//
|
||||||
|
// key = SHA256(baseKey)
|
||||||
|
encryptionKey, err := genEncryptionKey(keyRing)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before encryption, we'll initialize our cipher with the target
|
||||||
|
// encryption key, and also read out our random 24-byte nonce we use
|
||||||
|
// for encryption. Note that we use NewX, not New, as the latter
|
||||||
|
// version requires a 12-byte nonce, not a 24-byte nonce.
|
||||||
|
cipher, err := chacha20poly1305.NewX(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var nonce [chacha20poly1305.NonceSizeX]byte
|
||||||
|
if _, err := rand.Read(nonce[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, we encrypted the final payload, and write out our
|
||||||
|
// ciphertext with nonce pre-pended.
|
||||||
|
ciphertext := cipher.Seal(nil, nonce[:], payload.Bytes(), nonce[:])
|
||||||
|
|
||||||
|
if _, err := w.Write(nonce[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := w.Write(ciphertext); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptPayloadFromReader attempts to decrypt the encrypted bytes within the
|
||||||
|
// passed io.Reader instance using the key derived from the passed keyRing. For
|
||||||
|
// further details regarding the key derivation protocol, see the
|
||||||
|
// genEncryptionKey method.
|
||||||
|
func decryptPayloadFromReader(payload io.Reader,
|
||||||
|
keyRing keychain.KeyRing) ([]byte, error) {
|
||||||
|
|
||||||
|
// First, we'll re-generate the encryption key that we use for all the
|
||||||
|
// SCBs.
|
||||||
|
encryptionKey, err := genEncryptionKey(keyRing)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, we'll read out the entire blob as we need to isolate the nonce
|
||||||
|
// from the rest of the ciphertext.
|
||||||
|
packedBackup, err := ioutil.ReadAll(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(packedBackup) < chacha20poly1305.NonceSizeX {
|
||||||
|
return nil, fmt.Errorf("payload size too small, must be at "+
|
||||||
|
"least %v bytes", chacha20poly1305.NonceSizeX)
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce := packedBackup[:chacha20poly1305.NonceSizeX]
|
||||||
|
ciphertext := packedBackup[chacha20poly1305.NonceSizeX:]
|
||||||
|
|
||||||
|
// Now that we have the cipher text and the nonce separated, we can go
|
||||||
|
// ahead and decrypt the final blob so we can properly serialized the
|
||||||
|
// SCB.
|
||||||
|
cipher, err := chacha20poly1305.NewX(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
plaintext, err := cipher.Open(nil, nonce, ciphertext, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return plaintext, nil
|
||||||
|
}
|
156
chanbackup/crypto_test.go
Normal file
156
chanbackup/crypto_test.go
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
testWalletPrivKey = []byte{
|
||||||
|
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
|
||||||
|
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
|
||||||
|
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
|
||||||
|
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockKeyRing struct {
|
||||||
|
fail bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
|
||||||
|
return keychain.KeyDescriptor{}, nil
|
||||||
|
}
|
||||||
|
func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
|
||||||
|
if m.fail {
|
||||||
|
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey)
|
||||||
|
return keychain.KeyDescriptor{
|
||||||
|
PubKey: pub,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEncryptDecryptPayload tests that given a static key, we're able to
|
||||||
|
// properly decrypt and encrypted payload. We also test that we'll reject a
|
||||||
|
// ciphertext that has been modified.
|
||||||
|
func TestEncryptDecryptPayload(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
payloadCases := []struct {
|
||||||
|
// plaintext is the string that we'll be encrypting.
|
||||||
|
plaintext []byte
|
||||||
|
|
||||||
|
// mutator allows a test case to modify the ciphertext before
|
||||||
|
// we attempt to decrypt it.
|
||||||
|
mutator func(*[]byte)
|
||||||
|
|
||||||
|
// valid indicates if this test should pass or fail.
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
// Proper payload, should decrypt.
|
||||||
|
{
|
||||||
|
plaintext: []byte("payload test plain text"),
|
||||||
|
mutator: nil,
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Mutator modifies cipher text, shouldn't decrypt.
|
||||||
|
{
|
||||||
|
plaintext: []byte("payload test plain text"),
|
||||||
|
mutator: func(p *[]byte) {
|
||||||
|
// Flip a byte in the payload to render it invalid.
|
||||||
|
(*p)[0] ^= 1
|
||||||
|
},
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Cipher text is too small, shouldn't decrypt.
|
||||||
|
{
|
||||||
|
plaintext: []byte("payload test plain text"),
|
||||||
|
mutator: func(p *[]byte) {
|
||||||
|
// Modify the cipher text to be zero length.
|
||||||
|
*p = []byte{}
|
||||||
|
},
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
for i, payloadCase := range payloadCases {
|
||||||
|
var cipherBuffer bytes.Buffer
|
||||||
|
|
||||||
|
// First, we'll encrypt the passed payload with our scheme.
|
||||||
|
payloadReader := bytes.NewBuffer(payloadCase.plaintext)
|
||||||
|
err := encryptPayloadToWriter(
|
||||||
|
*payloadReader, &cipherBuffer, keyRing,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable encrypt paylaod: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a mutator, then we'll wrong the mutator over the
|
||||||
|
// cipher text, then reset the main buffer and re-write the new
|
||||||
|
// cipher text.
|
||||||
|
if payloadCase.mutator != nil {
|
||||||
|
cipherText := cipherBuffer.Bytes()
|
||||||
|
|
||||||
|
payloadCase.mutator(&cipherText)
|
||||||
|
|
||||||
|
cipherBuffer.Reset()
|
||||||
|
cipherBuffer.Write(cipherText)
|
||||||
|
}
|
||||||
|
|
||||||
|
plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
// If this was meant to be a valid decryption, but we failed,
|
||||||
|
// then we'll return an error.
|
||||||
|
case err != nil && payloadCase.valid:
|
||||||
|
t.Fatalf("unable to decrypt valid payload case %v", i)
|
||||||
|
|
||||||
|
// If this was meant to be an invalid decryption, and we didn't
|
||||||
|
// fail, then we'll return an error.
|
||||||
|
case err == nil && !payloadCase.valid:
|
||||||
|
t.Fatalf("payload was invalid yet was able to decrypt")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only if this case was mean to be valid will we ensure the
|
||||||
|
// resulting decrypted plaintext matches the original input.
|
||||||
|
if payloadCase.valid &&
|
||||||
|
!bytes.Equal(plaintext, payloadCase.plaintext) {
|
||||||
|
t.Fatalf("#%v: expected %v, got %v: ", i,
|
||||||
|
payloadCase.plaintext, plaintext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInvalidKeyEncryption tests that encryption fails if we're unable to
|
||||||
|
// obtain a valid key.
|
||||||
|
func TestInvalidKeyEncryption(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := encryptPayloadToWriter(b, &b, &mockKeyRing{true})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error due to fail key gen")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInvalidKeyDecrytion tests that decryption fails if we're unable to
|
||||||
|
// obtain a valid key.
|
||||||
|
func TestInvalidKeyDecrytion(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
_, err := decryptPayloadFromReader(&b, &mockKeyRing{true})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error due to fail key gen")
|
||||||
|
}
|
||||||
|
}
|
45
chanbackup/log.go
Normal file
45
chanbackup/log.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/btcsuite/btclog"
|
||||||
|
"github.com/lightningnetwork/lnd/build"
|
||||||
|
)
|
||||||
|
|
||||||
|
// log is a logger that is initialized with no output filters. This
|
||||||
|
// means the package will not perform any logging by default until the caller
|
||||||
|
// requests it.
|
||||||
|
var log btclog.Logger
|
||||||
|
|
||||||
|
// The default amount of logging is none.
|
||||||
|
func init() {
|
||||||
|
UseLogger(build.NewSubLogger("CHBU", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableLog disables all library log output. Logging output is disabled
|
||||||
|
// by default until UseLogger is called.
|
||||||
|
func DisableLog() {
|
||||||
|
UseLogger(btclog.Disabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseLogger uses a specified Logger to output package logging info.
|
||||||
|
// This should be used in preference to SetLogWriter if the caller is also
|
||||||
|
// using btclog.
|
||||||
|
func UseLogger(logger btclog.Logger) {
|
||||||
|
log = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// logClosure is used to provide a closure over expensive logging operations so
|
||||||
|
// don't have to be performed when the logging level doesn't warrant it.
|
||||||
|
type logClosure func() string
|
||||||
|
|
||||||
|
// String invokes the underlying function and returns the result.
|
||||||
|
func (c logClosure) String() string {
|
||||||
|
return c()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newLogClosure returns a new closure over a function that returns a string
|
||||||
|
// which itself provides a Stringer interface so that it can be used with the
|
||||||
|
// logging system.
|
||||||
|
func newLogClosure(c func() string) logClosure {
|
||||||
|
return logClosure(c)
|
||||||
|
}
|
176
chanbackup/multi.go
Normal file
176
chanbackup/multi.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MultiBackupVersion denotes the version of the multi channel static channel
|
||||||
|
// backup. Based on this version, we know how to encode/decode packed/unpacked
|
||||||
|
// versions of multi backups.
|
||||||
|
type MultiBackupVersion byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultMultiVersion is the default version of the multi channel
|
||||||
|
// backup. The serialized format for this version is simply: version ||
|
||||||
|
// numBackups || SCBs...
|
||||||
|
DefaultMultiVersion = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
// Multi is a form of static channel backup that is amenable to being
|
||||||
|
// serialized in a single file. Rather than a series of ciphertexts, a
|
||||||
|
// multi-chan backup is a single ciphertext of all static channel backups
|
||||||
|
// concatenated. This form factor gives users a single blob that they can use
|
||||||
|
// to safely copy/obtain at anytime to backup their channels.
|
||||||
|
type Multi struct {
|
||||||
|
// Version is the version that should be observed when attempting to
|
||||||
|
// pack the multi backup.
|
||||||
|
Version MultiBackupVersion
|
||||||
|
|
||||||
|
// StaticBackups is the set of single channel backups that this multi
|
||||||
|
// backup is comprised of.
|
||||||
|
StaticBackups []Single
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackToWriter packs (encrypts+serializes) the target set of static channel
|
||||||
|
// backups into a single AEAD ciphertext into the passed io.Writer. This is the
|
||||||
|
// opposite of UnpackFromReader. The plaintext form of a multi-chan backup is
|
||||||
|
// the following: a 4 byte integer denoting the number of serialized static
|
||||||
|
// channel backups serialized, a series of serialized static channel backups
|
||||||
|
// concatenated. To pack this payload, we then apply our chacha20 AEAD to the
|
||||||
|
// entire payload, using the 24-byte nonce as associated data.
|
||||||
|
func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
|
||||||
|
// The only version that we know how to pack atm is version 0. Attempts
|
||||||
|
// to pack any other version will result in an error.
|
||||||
|
switch m.Version {
|
||||||
|
case DefaultMultiVersion:
|
||||||
|
break
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unable to pack unknown multi-version "+
|
||||||
|
"of %v", m.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
var multiBackupBuffer bytes.Buffer
|
||||||
|
|
||||||
|
// First, we'll write out the version of this multi channel baackup.
|
||||||
|
err := lnwire.WriteElements(&multiBackupBuffer, byte(m.Version))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we've written out the version of this multi-pack format,
|
||||||
|
// we'll now write the total number of backups to expect after this
|
||||||
|
// point.
|
||||||
|
numBackups := uint32(len(m.StaticBackups))
|
||||||
|
err = lnwire.WriteElements(&multiBackupBuffer, numBackups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, we'll serialize the raw plaintext version of each of the
|
||||||
|
// backup into the intermediate buffer.
|
||||||
|
for _, chanBackup := range m.StaticBackups {
|
||||||
|
err := chanBackup.Serialize(&multiBackupBuffer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to serialize backup "+
|
||||||
|
"for %v: %v", chanBackup.FundingOutpoint, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the plaintext multi backup assembled, we'll now encrypt it
|
||||||
|
// directly to the passed writer.
|
||||||
|
return encryptPayloadToWriter(multiBackupBuffer, w, keyRing)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnpackFromReader attempts to unpack (decrypt+deserialize) a packed
|
||||||
|
// multi-chan backup form the passed io.Reader. If we're unable to decrypt the
|
||||||
|
// any portion of the multi-chan backup, an error will be returned.
|
||||||
|
func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
|
||||||
|
// We'll attempt to read the entire packed backup, and also decrypt it
|
||||||
|
// using the passed key ring which is expected to be able to derive the
|
||||||
|
// encryption keys.
|
||||||
|
plaintextBackup, err := decryptPayloadFromReader(r, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
backupReader := bytes.NewReader(plaintextBackup)
|
||||||
|
|
||||||
|
// Now that we've decrypted the payload successfully, we can parse out
|
||||||
|
// each of the individual static channel backups.
|
||||||
|
|
||||||
|
// First, we'll need to read the version of this multi-back up so we
|
||||||
|
// can know how to unpack each of the individual SCB's.
|
||||||
|
var multiVersion byte
|
||||||
|
err = lnwire.ReadElements(backupReader, &multiVersion)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Version = MultiBackupVersion(multiVersion)
|
||||||
|
switch m.Version {
|
||||||
|
|
||||||
|
// The default version is simply a set of serialized SCB's with the
|
||||||
|
// number of total SCB's prepended to the front of the byte slice.
|
||||||
|
case DefaultMultiVersion:
|
||||||
|
// First, we'll need to read out the total number of backups
|
||||||
|
// that've been serialized into this multi-chan backup. Each
|
||||||
|
// backup is the same size, so we can continue until we've
|
||||||
|
// parsed out everything.
|
||||||
|
var numBackups uint32
|
||||||
|
err = lnwire.ReadElements(backupReader, &numBackups)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We'll continue to parse out each backup until we've read all
|
||||||
|
// that was indicated from the length prefix.
|
||||||
|
for ; numBackups != 0; numBackups-- {
|
||||||
|
// Attempt to parse out the net static channel backup,
|
||||||
|
// if it's been malformed, then we'll return with an
|
||||||
|
// error
|
||||||
|
var chanBackup Single
|
||||||
|
err := chanBackup.Deserialize(backupReader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect the next valid chan backup into the main
|
||||||
|
// multi backup slice.
|
||||||
|
m.StaticBackups = append(m.StaticBackups, chanBackup)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unable to unpack unknown multi-version "+
|
||||||
|
"of %v", multiVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): new key ring interface?
|
||||||
|
// * just returns key given params?
|
||||||
|
|
||||||
|
// PackedMulti represents a raw fully packed (serialized+encrypted)
|
||||||
|
// multi-channel static channel backup.
|
||||||
|
type PackedMulti []byte
|
||||||
|
|
||||||
|
// Unpack attempts to unpack (decrypt+desrialize) the target packed
|
||||||
|
// multi-channel back up. If we're unable to fully unpack this back, then an
|
||||||
|
// error will be returned.
|
||||||
|
func (p *PackedMulti) Unpack(keyRing keychain.KeyRing) (*Multi, error) {
|
||||||
|
var m Multi
|
||||||
|
|
||||||
|
packedReader := bytes.NewReader(*p)
|
||||||
|
if err := m.UnpackFromReader(packedReader, keyRing); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbsef): fuzz parsing
|
159
chanbackup/multi_test.go
Normal file
159
chanbackup/multi_test.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestMultiPackUnpack...
|
||||||
|
func TestMultiPackUnpack(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var multi Multi
|
||||||
|
numSingles := 10
|
||||||
|
originalSingles := make([]Single, 0, numSingles)
|
||||||
|
for i := 0; i < numSingles; i++ {
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to gen channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(channel, []net.Addr{addr1, addr2})
|
||||||
|
|
||||||
|
originalSingles = append(originalSingles, single)
|
||||||
|
multi.StaticBackups = append(multi.StaticBackups, single)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
versionTestCases := []struct {
|
||||||
|
// version is the pack/unpack version that we should use to
|
||||||
|
// decode/encode the final SCB.
|
||||||
|
version MultiBackupVersion
|
||||||
|
|
||||||
|
// valid tests us if this test case should pass or not.
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
// The default version, should pack/unpack with no problem.
|
||||||
|
{
|
||||||
|
version: DefaultSingleVersion,
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// A non-default version, atm this should result in a failure.
|
||||||
|
{
|
||||||
|
version: 99,
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, versionCase := range versionTestCases {
|
||||||
|
multi.Version = versionCase.version
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := multi.PackToWriter(&b, keyRing)
|
||||||
|
switch {
|
||||||
|
// If this is a valid test case, and we failed, then we'll
|
||||||
|
// return an error.
|
||||||
|
case err != nil && versionCase.valid:
|
||||||
|
t.Fatalf("#%v, unable to pack multi: %v", i, err)
|
||||||
|
|
||||||
|
// If this is an invalid test case, and we passed it, then
|
||||||
|
// we'll return an error.
|
||||||
|
case err == nil && !versionCase.valid:
|
||||||
|
t.Fatalf("#%v got nil error for invalid pack: %v",
|
||||||
|
i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a valid test case, then we'll continue to ensure
|
||||||
|
// we can unpack it, and also that if we mutate the packed
|
||||||
|
// version, then we trigger an error.
|
||||||
|
if versionCase.valid {
|
||||||
|
var unpackedMulti Multi
|
||||||
|
err = unpackedMulti.UnpackFromReader(&b, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("#%v unable to unpack multi: %v",
|
||||||
|
i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, we'll ensure that the unpacked version of the
|
||||||
|
// packed multi is the same as the original set.
|
||||||
|
if len(originalSingles) !=
|
||||||
|
len(unpackedMulti.StaticBackups) {
|
||||||
|
t.Fatalf("expected %v singles, got %v",
|
||||||
|
len(originalSingles),
|
||||||
|
len(unpackedMulti.StaticBackups))
|
||||||
|
}
|
||||||
|
for i := 0; i < numSingles; i++ {
|
||||||
|
assertSingleEqual(
|
||||||
|
t, originalSingles[i],
|
||||||
|
unpackedMulti.StaticBackups[i],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, we'll make a fake packed multi, it'll have an
|
||||||
|
// unknown version relative to what's implemented atm.
|
||||||
|
var fakePackedMulti bytes.Buffer
|
||||||
|
fakeRawMulti := bytes.NewBuffer(
|
||||||
|
bytes.Repeat([]byte{99}, 20),
|
||||||
|
)
|
||||||
|
err := encryptPayloadToWriter(
|
||||||
|
*fakeRawMulti, &fakePackedMulti, keyRing,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to pack fake multi; %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should reject this fake multi as it contains an
|
||||||
|
// unknown version.
|
||||||
|
err = unpackedMulti.UnpackFromReader(
|
||||||
|
&fakePackedMulti, keyRing,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("#%v unpack with unknown version "+
|
||||||
|
"should have failed", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPackedMultiUnpack tests that we're able to properly unpack a typed
|
||||||
|
// packed multi.
|
||||||
|
func TestPackedMultiUnpack(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
// First, we'll make a new unpacked multi with a random channel.
|
||||||
|
testChannel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to gen random channel: %v", err)
|
||||||
|
}
|
||||||
|
var multi Multi
|
||||||
|
multi.StaticBackups = append(
|
||||||
|
multi.StaticBackups, NewSingle(testChannel, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Now that we have our multi, we'll pack it into a new buffer.
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := multi.PackToWriter(&b, keyRing); err != nil {
|
||||||
|
t.Fatalf("unable to pack multi: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should be able to properly unpack this typed packed multi.
|
||||||
|
packedMulti := PackedMulti(b.Bytes())
|
||||||
|
unpackedMulti, err := packedMulti.Unpack(keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to unpack multi: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, the versions should match, and the unpacked singles also
|
||||||
|
// identical.
|
||||||
|
if multi.Version != unpackedMulti.Version {
|
||||||
|
t.Fatalf("version mismatch: expected %v got %v",
|
||||||
|
multi.Version, unpackedMulti.Version)
|
||||||
|
}
|
||||||
|
assertSingleEqual(
|
||||||
|
t, multi.StaticBackups[0], unpackedMulti.StaticBackups[0],
|
||||||
|
)
|
||||||
|
}
|
247
chanbackup/pubsub.go
Normal file
247
chanbackup/pubsub.go
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Swapper is an interface that allows the chanbackup.SubSwapper to update the
|
||||||
|
// main multi backup location once it learns of new channels or that prior
|
||||||
|
// channels have been closed.
|
||||||
|
type Swapper interface {
|
||||||
|
// UpdateAndSwap attempts to atomically update the main multi back up
|
||||||
|
// file location with the new fully packed multi-channel backup.
|
||||||
|
UpdateAndSwap(newBackup PackedMulti) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChannelWithAddrs bundles an open channel along with all the addresses for
|
||||||
|
// the channel peer.
|
||||||
|
//
|
||||||
|
// TODO(roasbeef): use channel shell instead?
|
||||||
|
type ChannelWithAddrs struct {
|
||||||
|
*channeldb.OpenChannel
|
||||||
|
|
||||||
|
// Addrs is the set of addresses that we can use to reach the target
|
||||||
|
// peer.
|
||||||
|
Addrs []net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChannelEvent packages a new update of new channels since subscription, and
|
||||||
|
// channels that have been opened since prior channel event.
|
||||||
|
type ChannelEvent struct {
|
||||||
|
// ClosedChans are the set of channels that have been closed since the
|
||||||
|
// last event.
|
||||||
|
ClosedChans []wire.OutPoint
|
||||||
|
|
||||||
|
// NewChans is the set of channels that have been opened since the last
|
||||||
|
// event.
|
||||||
|
NewChans []ChannelWithAddrs
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChannelSubscription represents an intent to be notified of any updates to
|
||||||
|
// the primary channel state.
|
||||||
|
type ChannelSubscription struct {
|
||||||
|
// ChanUpdates is a read-only channel that will be sent upon once the
|
||||||
|
// primary channel state is updated.
|
||||||
|
ChanUpdates <-chan ChannelEvent
|
||||||
|
|
||||||
|
// Cancel is a closure that allows the caller to cancel their
|
||||||
|
// subscription and free up any resources allocated.
|
||||||
|
Cancel func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChannelNotifier represents a system that allows the chanbackup.SubSwapper to
|
||||||
|
// be notified of any changes to the primary channel state.
|
||||||
|
type ChannelNotifier interface {
|
||||||
|
// SubscribeChans requests a new channel subscription relative to the
|
||||||
|
// initial set of known channels. We use the knownChans as a
|
||||||
|
// synchronization point to ensure that the chanbackup.SubSwapper does
|
||||||
|
// not miss any channel open or close events in the period between when
|
||||||
|
// it's created, and when it requests the channel subscription.
|
||||||
|
SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubSwapper subscribes to new updates to the open channel state, and then
|
||||||
|
// swaps out the on-disk channel backup state in response. This sub-system
|
||||||
|
// that will ensure that the multi chan backup file on disk will always be
|
||||||
|
// updated with the latest channel back up state. We'll receive new
|
||||||
|
// opened/closed channels from the ChannelNotifier, then use the Swapper to
|
||||||
|
// update the file state on disk with the new set of open channels. This can
|
||||||
|
// be used to implement a system that always keeps the multi-chan backup file
|
||||||
|
// on disk in a consistent state for safety purposes.
|
||||||
|
//
|
||||||
|
// TODO(roasbeef): better name lol
|
||||||
|
type SubSwapper struct {
|
||||||
|
started uint32
|
||||||
|
stopped uint32
|
||||||
|
|
||||||
|
// backupState are the set of SCBs for all open channels we know of.
|
||||||
|
backupState map[wire.OutPoint]Single
|
||||||
|
|
||||||
|
// chanEvents is an active subscription to receive new channel state
|
||||||
|
// over.
|
||||||
|
chanEvents *ChannelSubscription
|
||||||
|
|
||||||
|
// keyRing is the main key ring that will allow us to pack the new
|
||||||
|
// multi backup.
|
||||||
|
keyRing keychain.KeyRing
|
||||||
|
|
||||||
|
Swapper
|
||||||
|
|
||||||
|
quit chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSubSwapper creates a new instance of the SubSwapper given the starting
|
||||||
|
// set of channels, and the required interfaces to be notified of new channel
|
||||||
|
// updates, pack a multi backup, and swap the current best backup from its
|
||||||
|
// storage location.
|
||||||
|
func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier,
|
||||||
|
keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) {
|
||||||
|
|
||||||
|
// First, we'll subscribe to the latest set of channel updates given
|
||||||
|
// the set of channels we already know of.
|
||||||
|
knownChans := make(map[wire.OutPoint]struct{})
|
||||||
|
for _, chanBackup := range startingChans {
|
||||||
|
knownChans[chanBackup.FundingOutpoint] = struct{}{}
|
||||||
|
}
|
||||||
|
chanEvents, err := chanNotifier.SubscribeChans(knownChans)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, we'll construct our own backup state so we can add/remove
|
||||||
|
// channels that have been opened and closed.
|
||||||
|
backupState := make(map[wire.OutPoint]Single)
|
||||||
|
for _, chanBackup := range startingChans {
|
||||||
|
backupState[chanBackup.FundingOutpoint] = chanBackup
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SubSwapper{
|
||||||
|
backupState: backupState,
|
||||||
|
chanEvents: chanEvents,
|
||||||
|
keyRing: keyRing,
|
||||||
|
Swapper: backupSwapper,
|
||||||
|
quit: make(chan struct{}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the chanbackup.SubSwapper.
|
||||||
|
func (s *SubSwapper) Start() error {
|
||||||
|
if !atomic.CompareAndSwapUint32(&s.started, 0, 1) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Starting chanbackup.SubSwapper")
|
||||||
|
|
||||||
|
s.wg.Add(1)
|
||||||
|
go s.backupUpdater()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop signals the SubSwapper to being a graceful shutdown.
|
||||||
|
func (s *SubSwapper) Stop() error {
|
||||||
|
if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Stopping chanbackup.SubSwapper")
|
||||||
|
|
||||||
|
close(s.quit)
|
||||||
|
s.wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// backupFileUpdater is the primary goroutine of the SubSwapper which is
|
||||||
|
// responsible for listening for changes to the channel, and updating the
|
||||||
|
// persistent multi backup state with a new packed multi of the latest channel
|
||||||
|
// state.
|
||||||
|
func (s *SubSwapper) backupUpdater() {
|
||||||
|
// Ensure that once we exit, we'll cancel our active channel
|
||||||
|
// subscription.
|
||||||
|
defer s.chanEvents.Cancel()
|
||||||
|
defer s.wg.Done()
|
||||||
|
|
||||||
|
log.Debugf("SubSwapper's backupUpdater is active!")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// The channel state has been modified! We'll evaluate all
|
||||||
|
// changes, and swap out the old packed multi with a new one
|
||||||
|
// with the latest channel state.
|
||||||
|
case chanUpdate := <-s.chanEvents.ChanUpdates:
|
||||||
|
oldStateSize := len(s.backupState)
|
||||||
|
|
||||||
|
// For all new open channels, we'll create a new SCB
|
||||||
|
// given the required information.
|
||||||
|
for _, newChan := range chanUpdate.NewChans {
|
||||||
|
log.Debugf("Adding chanenl %v to backup state",
|
||||||
|
newChan.FundingOutpoint)
|
||||||
|
|
||||||
|
s.backupState[newChan.FundingOutpoint] = NewSingle(
|
||||||
|
newChan.OpenChannel, newChan.Addrs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For all closed channels, we'll remove the prior
|
||||||
|
// backup state.
|
||||||
|
for _, closedChan := range chanUpdate.ClosedChans {
|
||||||
|
log.Debugf("Removing channel %v from backup "+
|
||||||
|
"state", newLogClosure(func() string {
|
||||||
|
return closedChan.String()
|
||||||
|
}))
|
||||||
|
|
||||||
|
delete(s.backupState, closedChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
newStateSize := len(s.backupState)
|
||||||
|
|
||||||
|
// With our updated channel state obtained, we'll
|
||||||
|
// create a new multi from our series of singles.
|
||||||
|
var newMulti Multi
|
||||||
|
for _, backup := range s.backupState {
|
||||||
|
newMulti.StaticBackups = append(
|
||||||
|
newMulti.StaticBackups, backup,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that our multi has been assembled, we'll attempt
|
||||||
|
// to pack (encrypt+encode) the new channel state to
|
||||||
|
// our target reader.
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := newMulti.PackToWriter(&b, s.keyRing)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to pack multi backup: %v",
|
||||||
|
err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Updating on-disk multi SCB backup: "+
|
||||||
|
"num_old_chans=%v, num_new_chans=%v",
|
||||||
|
oldStateSize, newStateSize)
|
||||||
|
|
||||||
|
// Finally, we'll swap out the old backup for this new
|
||||||
|
// one in a single atomic step.
|
||||||
|
err = s.Swapper.UpdateAndSwap(
|
||||||
|
PackedMulti(b.Bytes()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("unable to update multi "+
|
||||||
|
"backup: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exit at once if a quit signal is detected.
|
||||||
|
case <-s.quit:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
234
chanbackup/pubsub_test.go
Normal file
234
chanbackup/pubsub_test.go
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockSwapper struct {
|
||||||
|
fail bool
|
||||||
|
|
||||||
|
swaps chan PackedMulti
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockSwapper() *mockSwapper {
|
||||||
|
return &mockSwapper{
|
||||||
|
swaps: make(chan PackedMulti),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockSwapper) UpdateAndSwap(newBackup PackedMulti) error {
|
||||||
|
if m.fail {
|
||||||
|
return fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.swaps <- newBackup
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockChannelNotifier struct {
|
||||||
|
fail bool
|
||||||
|
|
||||||
|
chanEvents chan ChannelEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockChannelNotifier() *mockChannelNotifier {
|
||||||
|
return &mockChannelNotifier{
|
||||||
|
chanEvents: make(chan ChannelEvent),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
|
||||||
|
*ChannelSubscription, error) {
|
||||||
|
|
||||||
|
if m.fail {
|
||||||
|
return nil, fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ChannelSubscription{
|
||||||
|
ChanUpdates: m.chanEvents,
|
||||||
|
Cancel: func() {
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewSubSwapperSubscribeFail tests that if we're unable to obtain a
|
||||||
|
// channel subscription, then the entire sub-swapper will fail to start.
|
||||||
|
func TestNewSubSwapperSubscribeFail(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
var swapper mockSwapper
|
||||||
|
chanNotifier := mockChannelNotifier{
|
||||||
|
fail: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected fail due to lack of subscription")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper,
|
||||||
|
subSwapper *SubSwapper, keyRing keychain.KeyRing,
|
||||||
|
expectedChanSet map[wire.OutPoint]Single) {
|
||||||
|
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case newPackedMulti := <-swapper.swaps:
|
||||||
|
// If we unpack the new multi, then we should find all the old
|
||||||
|
// channels, and also the new channel included and any deleted
|
||||||
|
// channel omitted..
|
||||||
|
newMulti, err := newPackedMulti.Unpack(keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to unpack multi: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure that once unpacked, the current backup has the
|
||||||
|
// expected number of Singles.
|
||||||
|
if len(newMulti.StaticBackups) != len(expectedChanSet) {
|
||||||
|
t.Fatalf("new backup wasn't included: expected %v "+
|
||||||
|
"backups have %v", len(expectedChanSet),
|
||||||
|
len(newMulti.StaticBackups))
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should also find all the old and new channels in this new
|
||||||
|
// backup.
|
||||||
|
for _, backup := range newMulti.StaticBackups {
|
||||||
|
_, ok := expectedChanSet[backup.FundingOutpoint]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("didn't find backup in original set: %v",
|
||||||
|
backup.FundingOutpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The internal state of the sub-swapper should also be one
|
||||||
|
// larger.
|
||||||
|
if !reflect.DeepEqual(expectedChanSet, subSwapper.backupState) {
|
||||||
|
t.Fatalf("backup set doesn't match: expected %v got %v",
|
||||||
|
spew.Sdump(expectedChanSet),
|
||||||
|
spew.Sdump(subSwapper.backupState))
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Fatalf("update swapper didn't swap out multi")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSubSwapperIdempotentStartStop tests that calling the Start/Stop methods
|
||||||
|
// multiple time is permitted.
|
||||||
|
func TestSubSwapperIdempotentStartStop(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
var (
|
||||||
|
swapper mockSwapper
|
||||||
|
chanNotifier mockChannelNotifier
|
||||||
|
)
|
||||||
|
|
||||||
|
subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to init subSwapper: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
subSwapper.Start()
|
||||||
|
subSwapper.Start()
|
||||||
|
|
||||||
|
subSwapper.Stop()
|
||||||
|
subSwapper.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSubSwapperUpdater tests that the SubSwapper will properly swap out
|
||||||
|
// new/old channels within the channel set, and notify the swapper to update
|
||||||
|
// the master multi file backup.
|
||||||
|
func TestSubSwapperUpdater(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
chanNotifier := newMockChannelNotifier()
|
||||||
|
swapper := newMockSwapper()
|
||||||
|
|
||||||
|
// First, we'll start out by creating a channels set for the initial
|
||||||
|
// set of channels known to the sub-swapper.
|
||||||
|
const numStartingChans = 3
|
||||||
|
initialChanSet := make([]Single, 0, numStartingChans)
|
||||||
|
backupSet := make(map[wire.OutPoint]Single)
|
||||||
|
for i := 0; i < numStartingChans; i++ {
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make test chan: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(channel, nil)
|
||||||
|
|
||||||
|
backupSet[channel.FundingOutpoint] = single
|
||||||
|
initialChanSet = append(initialChanSet, single)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With our channel set created, we'll make a fresh sub swapper
|
||||||
|
// instance to begin our test.
|
||||||
|
subSwapper, err := NewSubSwapper(
|
||||||
|
initialChanSet, chanNotifier, keyRing, swapper,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to make swapper: %v", err)
|
||||||
|
}
|
||||||
|
if err := subSwapper.Start(); err != nil {
|
||||||
|
t.Fatalf("unable to start sub swapper: %v", err)
|
||||||
|
}
|
||||||
|
defer subSwapper.Stop()
|
||||||
|
|
||||||
|
// Now that the sub-swapper is active, we'll notify to add a brand new
|
||||||
|
// channel to the channel state.
|
||||||
|
newChannel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create new chan: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the new channel created, we'll send a new update to the main
|
||||||
|
// goroutine telling it about this new channel.
|
||||||
|
select {
|
||||||
|
case chanNotifier.chanEvents <- ChannelEvent{
|
||||||
|
NewChans: []ChannelWithAddrs{
|
||||||
|
{
|
||||||
|
OpenChannel: newChannel,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}:
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Fatalf("update swapper didn't read new channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
backupSet[newChannel.FundingOutpoint] = NewSingle(newChannel, nil)
|
||||||
|
|
||||||
|
// At this point, the sub-swapper should now have packed a new multi,
|
||||||
|
// and then sent it to the swapper so the back up can be updated.
|
||||||
|
assertExpectedBackupSwap(t, swapper, subSwapper, keyRing, backupSet)
|
||||||
|
|
||||||
|
// We'll now trigger an update to remove an existing channel.
|
||||||
|
chanToDelete := initialChanSet[0].FundingOutpoint
|
||||||
|
select {
|
||||||
|
case chanNotifier.chanEvents <- ChannelEvent{
|
||||||
|
ClosedChans: []wire.OutPoint{chanToDelete},
|
||||||
|
}:
|
||||||
|
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
t.Fatalf("update swapper didn't read new channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(backupSet, chanToDelete)
|
||||||
|
|
||||||
|
// Verify that the new set of backups, now has one less after the
|
||||||
|
// sub-swapper switches the new set with the old.
|
||||||
|
assertExpectedBackupSwap(t, swapper, subSwapper, keyRing, backupSet)
|
||||||
|
}
|
114
chanbackup/recover.go
Normal file
114
chanbackup/recover.go
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChannelRestorer is an interface that allows the Recover method to map the
|
||||||
|
// set of single channel backups into a set of "channel shells" and store these
|
||||||
|
// persistently on disk. The channel shell should contain all the information
|
||||||
|
// needed to execute the data loss recovery protocol once the channel peer is
|
||||||
|
// connected to.
|
||||||
|
type ChannelRestorer interface {
|
||||||
|
// RestoreChansFromSingles attempts to map the set of single channel
|
||||||
|
// backups to channel shells that will be stored persistently. Once
|
||||||
|
// these shells have been stored on disk, we'll be able to connect to
|
||||||
|
// the channel peer an execute the data loss recovery protocol.
|
||||||
|
RestoreChansFromSingles(...Single) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeerConnector is an interface that allows the Recover method to connect to
|
||||||
|
// the target node given the set of possible addresses.
|
||||||
|
type PeerConnector interface {
|
||||||
|
// ConnectPeer attempts to connect to the target node at the set of
|
||||||
|
// available addresses. Once this method returns with a non-nil error,
|
||||||
|
// the connector should attempt to persistently connect to the target
|
||||||
|
// peer in the background as a persistent attempt.
|
||||||
|
ConnectPeer(node *btcec.PublicKey, addrs []net.Addr) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recover attempts to recover the static channel state from a set of static
|
||||||
|
// channel backups. If successfully, the database will be populated with a
|
||||||
|
// series of "shell" channels. These "shell" channels cannot be used to operate
|
||||||
|
// the channel as normal, but instead are meant to be used to enter the data
|
||||||
|
// loss recovery phase, and recover the settled funds within
|
||||||
|
// the channel. In addition a LinkNode will be created for each new peer as
|
||||||
|
// well, in order to expose the addressing information required to locate to
|
||||||
|
// and connect to each peer in order to initiate the recovery protocol.
|
||||||
|
func Recover(backups []Single, restorer ChannelRestorer,
|
||||||
|
peerConnector PeerConnector) error {
|
||||||
|
|
||||||
|
for _, backup := range backups {
|
||||||
|
log.Infof("Restoring ChannelPoint(%v) to disk: ",
|
||||||
|
backup.FundingOutpoint)
|
||||||
|
|
||||||
|
err := restorer.RestoreChansFromSingles(backup)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Attempting to connect to node=%x (addrs=%v) to "+
|
||||||
|
"restore ChannelPoint(%v)",
|
||||||
|
backup.RemoteNodePub.SerializeCompressed(),
|
||||||
|
newLogClosure(func() string {
|
||||||
|
return spew.Sdump(backup.Addresses)
|
||||||
|
}), backup.FundingOutpoint)
|
||||||
|
|
||||||
|
err = peerConnector.ConnectPeer(
|
||||||
|
backup.RemoteNodePub, backup.Addresses,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): to handle case where node has changed addrs,
|
||||||
|
// need to subscribe to new updates for target node pub to
|
||||||
|
// attempt to connect to other addrs
|
||||||
|
//
|
||||||
|
// * just to to fresh w/ call to node addrs and de-dup?
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): more specific keychain interface?
|
||||||
|
|
||||||
|
// UnpackAndRecoverSingles is a one-shot method, that given a set of packed
|
||||||
|
// single channel backups, will restore the channel state to a channel shell,
|
||||||
|
// and also reach out to connect to any of the known node addresses for that
|
||||||
|
// channel. It is assumes that after this method exists, if a connection we
|
||||||
|
// able to be established, then then PeerConnector will continue to attempt to
|
||||||
|
// re-establish a persistent connection in the background.
|
||||||
|
func UnpackAndRecoverSingles(singles PackedSingles,
|
||||||
|
keyChain keychain.KeyRing, restorer ChannelRestorer,
|
||||||
|
peerConnector PeerConnector) error {
|
||||||
|
|
||||||
|
chanBackups, err := singles.Unpack(keyChain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Recover(chanBackups, restorer, peerConnector)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnpackAndRecoverMulti is a one-shot method, that given a set of packed
|
||||||
|
// multi-channel backups, will restore the channel states to channel shells,
|
||||||
|
// and also reach out to connect to any of the known node addresses for that
|
||||||
|
// channel. It is assumes that after this method exists, if a connection we
|
||||||
|
// able to be established, then then PeerConnector will continue to attempt to
|
||||||
|
// re-establish a persistent connection in the background.
|
||||||
|
func UnpackAndRecoverMulti(packedMulti PackedMulti,
|
||||||
|
keyChain keychain.KeyRing, restorer ChannelRestorer,
|
||||||
|
peerConnector PeerConnector) error {
|
||||||
|
|
||||||
|
chanBackups, err := packedMulti.Unpack(keyChain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Recover(chanBackups.StaticBackups, restorer, peerConnector)
|
||||||
|
}
|
232
chanbackup/recover_test.go
Normal file
232
chanbackup/recover_test.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockChannelRestorer struct {
|
||||||
|
fail bool
|
||||||
|
|
||||||
|
callCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChannelRestorer) RestoreChansFromSingles(...Single) error {
|
||||||
|
if m.fail {
|
||||||
|
return fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.callCount++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockPeerConnector struct {
|
||||||
|
fail bool
|
||||||
|
|
||||||
|
callCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPeerConnector) ConnectPeer(node *btcec.PublicKey,
|
||||||
|
addrs []net.Addr) error {
|
||||||
|
|
||||||
|
if m.fail {
|
||||||
|
return fmt.Errorf("fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.callCount++
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUnpackAndRecoverSingles tests that we're able to properly unpack and
|
||||||
|
// recover a set of packed singles.
|
||||||
|
func TestUnpackAndRecoverSingles(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
// First, we'll create a number of single chan backups that we'll
|
||||||
|
// shortly back to so we can begin our recovery attempt.
|
||||||
|
numSingles := 10
|
||||||
|
backups := make([]Single, 0, numSingles)
|
||||||
|
var packedBackups PackedSingles
|
||||||
|
for i := 0; i < numSingles; i++ {
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable make channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(channel, nil)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := single.PackToWriter(&b, keyRing); err != nil {
|
||||||
|
t.Fatalf("unable to pack single: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
backups = append(backups, single)
|
||||||
|
packedBackups = append(packedBackups, b.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
chanRestorer := mockChannelRestorer{}
|
||||||
|
peerConnector := mockPeerConnector{}
|
||||||
|
|
||||||
|
// Now that we have our backups (packed and unpacked), we'll attempt to
|
||||||
|
// restore them all in a single batch.
|
||||||
|
|
||||||
|
// If we make the channel restore fail, then the entire method should
|
||||||
|
// as well
|
||||||
|
chanRestorer.fail = true
|
||||||
|
err := UnpackAndRecoverSingles(
|
||||||
|
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("restoration should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
chanRestorer.fail = false
|
||||||
|
|
||||||
|
// If we make the peer connector fail, then the entire method should as
|
||||||
|
// well
|
||||||
|
peerConnector.fail = true
|
||||||
|
err = UnpackAndRecoverSingles(
|
||||||
|
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("restoration should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
chanRestorer.callCount--
|
||||||
|
peerConnector.fail = false
|
||||||
|
|
||||||
|
// Next, we'll ensure that if all the interfaces function as expected,
|
||||||
|
// then the channels will properly be unpacked and restored.
|
||||||
|
err = UnpackAndRecoverSingles(
|
||||||
|
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to recover chans: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both the restorer, and connector should have been called 10 times,
|
||||||
|
// once for each backup.
|
||||||
|
if chanRestorer.callCount != numSingles {
|
||||||
|
t.Fatalf("expected %v calls, instead got %v",
|
||||||
|
numSingles, chanRestorer.callCount)
|
||||||
|
}
|
||||||
|
if peerConnector.callCount != numSingles {
|
||||||
|
t.Fatalf("expected %v calls, instead got %v",
|
||||||
|
numSingles, peerConnector.callCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we modify the keyRing, then unpacking should fail.
|
||||||
|
keyRing.fail = true
|
||||||
|
err = UnpackAndRecoverSingles(
|
||||||
|
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("unpacking should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): verify proper call args
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestUnpackAndRecoverMulti tests that we're able to properly unpack and
|
||||||
|
// recover a packed multi.
|
||||||
|
func TestUnpackAndRecoverMulti(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
// First, we'll create a number of single chan backups that we'll
|
||||||
|
// shortly back to so we can begin our recovery attempt.
|
||||||
|
numSingles := 10
|
||||||
|
backups := make([]Single, 0, numSingles)
|
||||||
|
for i := 0; i < numSingles; i++ {
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable make channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(channel, nil)
|
||||||
|
|
||||||
|
backups = append(backups, single)
|
||||||
|
}
|
||||||
|
|
||||||
|
multi := Multi{
|
||||||
|
StaticBackups: backups,
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := multi.PackToWriter(&b, keyRing); err != nil {
|
||||||
|
t.Fatalf("unable to pack multi: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next, we'll pack the set of singles into a packed multi, and also
|
||||||
|
// create the set of interfaces we need to carry out the remainder of
|
||||||
|
// the test.
|
||||||
|
packedMulti := PackedMulti(b.Bytes())
|
||||||
|
|
||||||
|
chanRestorer := mockChannelRestorer{}
|
||||||
|
peerConnector := mockPeerConnector{}
|
||||||
|
|
||||||
|
// If we make the channel restore fail, then the entire method should
|
||||||
|
// as well
|
||||||
|
chanRestorer.fail = true
|
||||||
|
err := UnpackAndRecoverMulti(
|
||||||
|
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("restoration should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
chanRestorer.fail = false
|
||||||
|
|
||||||
|
// If we make the peer connector fail, then the entire method should as
|
||||||
|
// well
|
||||||
|
peerConnector.fail = true
|
||||||
|
err = UnpackAndRecoverMulti(
|
||||||
|
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("restoration should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
chanRestorer.callCount--
|
||||||
|
peerConnector.fail = false
|
||||||
|
|
||||||
|
// Next, we'll ensure that if all the interfaces function as expected,
|
||||||
|
// then the channels will properly be unpacked and restored.
|
||||||
|
err = UnpackAndRecoverMulti(
|
||||||
|
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to recover chans: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both the restorer, and connector should have been called 10 times,
|
||||||
|
// once for each backup.
|
||||||
|
if chanRestorer.callCount != numSingles {
|
||||||
|
t.Fatalf("expected %v calls, instead got %v",
|
||||||
|
numSingles, chanRestorer.callCount)
|
||||||
|
}
|
||||||
|
if peerConnector.callCount != numSingles {
|
||||||
|
t.Fatalf("expected %v calls, instead got %v",
|
||||||
|
numSingles, peerConnector.callCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we modify the keyRing, then unpacking should fail.
|
||||||
|
keyRing.fail = true
|
||||||
|
err = UnpackAndRecoverMulti(
|
||||||
|
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("unpacking should have failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): verify proper call args
|
||||||
|
}
|
346
chanbackup/single.go
Normal file
346
chanbackup/single.go
Normal file
@ -0,0 +1,346 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SingleBackupVersion denotes the version of the single static channel backup.
|
||||||
|
// Based on this version, we know how to pack/unpack serialized versions of the
|
||||||
|
// backup.
|
||||||
|
type SingleBackupVersion byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultSingleVersion is the defautl version of the single channel
|
||||||
|
// backup. The seralized version of this static channel backup is
|
||||||
|
// simply: version || SCB. Where SCB is the known format of the
|
||||||
|
// version.
|
||||||
|
DefaultSingleVersion = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
// Single is a static description of an existing channel that can be used for
|
||||||
|
// the purposes of backing up. The fields in this struct allow a node to
|
||||||
|
// recover the settled funds within a channel in the case of partial or
|
||||||
|
// complete data loss. We provide the network address that we last used to
|
||||||
|
// connect to the peer as well, in case the node stops advertising the IP on
|
||||||
|
// the network for whatever reason.
|
||||||
|
//
|
||||||
|
// TODO(roasbeef): suffix version into struct?
|
||||||
|
type Single struct {
|
||||||
|
// Version is the version that should be observed when attempting to
|
||||||
|
// pack the single backup.
|
||||||
|
Version SingleBackupVersion
|
||||||
|
|
||||||
|
// ChainHash is a hash which represents the blockchain that this
|
||||||
|
// channel will be opened within. This value is typically the genesis
|
||||||
|
// hash. In the case that the original chain went through a contentious
|
||||||
|
// hard-fork, then this value will be tweaked using the unique fork
|
||||||
|
// point on each branch.
|
||||||
|
ChainHash chainhash.Hash
|
||||||
|
|
||||||
|
// FundingOutpoint is the outpoint of the final funding transaction.
|
||||||
|
// This value uniquely and globally identities the channel within the
|
||||||
|
// target blockchain as specified by the chain hash parameter.
|
||||||
|
FundingOutpoint wire.OutPoint
|
||||||
|
|
||||||
|
// ShortChannelID encodes the exact location in the chain in which the
|
||||||
|
// channel was initially confirmed. This includes: the block height,
|
||||||
|
// transaction index, and the output within the target transaction.
|
||||||
|
ShortChannelID lnwire.ShortChannelID
|
||||||
|
|
||||||
|
// RemoteNodePub is the identity public key of the remote node this
|
||||||
|
// channel has been established with.
|
||||||
|
RemoteNodePub *btcec.PublicKey
|
||||||
|
|
||||||
|
// Addresses is a list of IP address in which either we were able to
|
||||||
|
// reach the node over in the past, OR we received an incoming
|
||||||
|
// authenticated connection for the stored identity public key.
|
||||||
|
Addresses []net.Addr
|
||||||
|
|
||||||
|
// CsvDelay is the local CSV delay used within the channel. We may need
|
||||||
|
// this value to reconstruct our script to recover the funds on-chain
|
||||||
|
// after a force close.
|
||||||
|
CsvDelay uint16
|
||||||
|
|
||||||
|
// PaymentBasePoint describes how to derive base public that's used to
|
||||||
|
// deriving the key used within the non-delayed pay-to-self output on
|
||||||
|
// the commitment transaction for a node. With this information, we can
|
||||||
|
// re-derive the private key needed to sweep the funds on-chain.
|
||||||
|
PaymentBasePoint keychain.KeyLocator
|
||||||
|
|
||||||
|
// ShaChainRootDesc describes how to derive the private key that was
|
||||||
|
// used as the shachain root for this channel.
|
||||||
|
ShaChainRootDesc keychain.KeyDescriptor
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSingle creates a new static channel backup based on an existing open
|
||||||
|
// channel. We also pass in the set of addresses that we used in the past to
|
||||||
|
// connect to the channel peer.
|
||||||
|
func NewSingle(channel *channeldb.OpenChannel,
|
||||||
|
nodeAddrs []net.Addr) Single {
|
||||||
|
|
||||||
|
chanCfg := channel.LocalChanCfg
|
||||||
|
|
||||||
|
// TODO(roasbeef): update after we start to store the KeyLoc for
|
||||||
|
// shachain root
|
||||||
|
|
||||||
|
// We'll need to obtain the shachain root which is derived directly
|
||||||
|
// from a private key in our keychain.
|
||||||
|
var b bytes.Buffer
|
||||||
|
channel.RevocationProducer.Encode(&b) // Can't return an error.
|
||||||
|
|
||||||
|
// Once we have the root, we'll make a public key from it, such that
|
||||||
|
// the backups plaintext don't carry any private information. When we
|
||||||
|
// go to recover, we'll present this in order to derive the private
|
||||||
|
// key.
|
||||||
|
_, shaChainPoint := btcec.PrivKeyFromBytes(btcec.S256(), b.Bytes())
|
||||||
|
|
||||||
|
return Single{
|
||||||
|
ChainHash: channel.ChainHash,
|
||||||
|
FundingOutpoint: channel.FundingOutpoint,
|
||||||
|
ShortChannelID: channel.ShortChannelID,
|
||||||
|
RemoteNodePub: channel.IdentityPub,
|
||||||
|
Addresses: nodeAddrs,
|
||||||
|
CsvDelay: chanCfg.CsvDelay,
|
||||||
|
PaymentBasePoint: chanCfg.PaymentBasePoint.KeyLocator,
|
||||||
|
ShaChainRootDesc: keychain.KeyDescriptor{
|
||||||
|
PubKey: shaChainPoint,
|
||||||
|
KeyLocator: keychain.KeyLocator{
|
||||||
|
Family: keychain.KeyFamilyRevocationRoot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize attempts to write out the serialized version of the target
|
||||||
|
// StaticChannelBackup into the passed io.Writer.
|
||||||
|
func (s *Single) Serialize(w io.Writer) error {
|
||||||
|
// Check to ensure that we'll only attempt to serialize a version that
|
||||||
|
// we're aware of.
|
||||||
|
switch s.Version {
|
||||||
|
case DefaultSingleVersion:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unable to serialize w/ unknown "+
|
||||||
|
"version: %v", s.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the sha chain root has specified a public key (which is
|
||||||
|
// optional), then we'll encode it now.
|
||||||
|
var shaChainPub [33]byte
|
||||||
|
if s.ShaChainRootDesc.PubKey != nil {
|
||||||
|
copy(
|
||||||
|
shaChainPub[:],
|
||||||
|
s.ShaChainRootDesc.PubKey.SerializeCompressed(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First we gather the SCB as is into a temporary buffer so we can
|
||||||
|
// determine the total length. Before we write out the serialized SCB,
|
||||||
|
// we write the length which allows us to skip any Singles that we
|
||||||
|
// don't know of when decoding a multi.
|
||||||
|
var singleBytes bytes.Buffer
|
||||||
|
if err := lnwire.WriteElements(
|
||||||
|
&singleBytes,
|
||||||
|
s.ChainHash[:],
|
||||||
|
s.FundingOutpoint,
|
||||||
|
s.ShortChannelID,
|
||||||
|
s.RemoteNodePub,
|
||||||
|
s.Addresses,
|
||||||
|
s.CsvDelay,
|
||||||
|
uint32(s.PaymentBasePoint.Family),
|
||||||
|
s.PaymentBasePoint.Index,
|
||||||
|
shaChainPub[:],
|
||||||
|
uint32(s.ShaChainRootDesc.KeyLocator.Family),
|
||||||
|
s.ShaChainRootDesc.KeyLocator.Index,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return lnwire.WriteElements(
|
||||||
|
w,
|
||||||
|
byte(s.Version),
|
||||||
|
uint16(len(singleBytes.Bytes())),
|
||||||
|
singleBytes.Bytes(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackToWriter is similar to the Serialize method, but takes the operation a
|
||||||
|
// step further by encryption the raw bytes of the static channel back up. For
|
||||||
|
// encryption we use the chacah20poly1305 AEAD cipher with a 24 byte nonce and
|
||||||
|
// 32-byte key size. We use a 24-byte nonce, as we can't ensure that we have a
|
||||||
|
// global counter to use as a sequence number for nonces, and want to ensure
|
||||||
|
// that we're able to decrypt these blobs without any additional context. We
|
||||||
|
// derive the key that we use for encryption via a SHA2 operation of the with
|
||||||
|
// the golden keychain.KeyFamilyStaticBackup base encryption key. We then take
|
||||||
|
// the serialized resulting shared secret point, and hash it using sha256 to
|
||||||
|
// obtain the key that we'll use for encryption. When using the AEAD, we pass
|
||||||
|
// the nonce as associated data such that we'll be able to package the two
|
||||||
|
// together for storage. Before writing out the encrypted payload, we prepend
|
||||||
|
// the nonce to the final blob.
|
||||||
|
func (s *Single) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
|
||||||
|
// First, we'll serialize the SCB (StaticChannelBackup) into a
|
||||||
|
// temporary buffer so we can store it in a temporary place before we
|
||||||
|
// go to encrypt the entire thing.
|
||||||
|
var rawBytes bytes.Buffer
|
||||||
|
if err := s.Serialize(&rawBytes); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, we'll encrypt the raw serialized SCB (using the nonce as
|
||||||
|
// associated data), and write out the ciphertext prepend with the
|
||||||
|
// nonce that we used to the passed io.Reader.
|
||||||
|
return encryptPayloadToWriter(rawBytes, w, keyRing)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize attempts to read the raw plaintext serialized SCB from the
|
||||||
|
// passed io.Reader. If the method is successful, then the target
|
||||||
|
// StaticChannelBackup will be fully populated.
|
||||||
|
func (s *Single) Deserialize(r io.Reader) error {
|
||||||
|
// First, we'll need to read the version of this single-back up so we
|
||||||
|
// can know how to unpack each of the SCB.
|
||||||
|
var version byte
|
||||||
|
err := lnwire.ReadElements(r, &version)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Version = SingleBackupVersion(version)
|
||||||
|
|
||||||
|
switch s.Version {
|
||||||
|
case DefaultSingleVersion:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unable to de-serialize w/ unknown "+
|
||||||
|
"version: %v", s.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
var length uint16
|
||||||
|
if err := lnwire.ReadElements(r, &length); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = lnwire.ReadElements(
|
||||||
|
r, s.ChainHash[:], &s.FundingOutpoint, &s.ShortChannelID,
|
||||||
|
&s.RemoteNodePub, &s.Addresses, &s.CsvDelay,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var keyFam uint32
|
||||||
|
if err := lnwire.ReadElements(r, &keyFam); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.PaymentBasePoint.Family = keychain.KeyFamily(keyFam)
|
||||||
|
|
||||||
|
err = lnwire.ReadElements(r, &s.PaymentBasePoint.Index)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, we'll parse out the ShaChainRootDesc.
|
||||||
|
var (
|
||||||
|
shaChainPub [33]byte
|
||||||
|
zeroPub [33]byte
|
||||||
|
)
|
||||||
|
if err := lnwire.ReadElements(r, shaChainPub[:]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since this field is optional, we'll check to see if the pubkey has
|
||||||
|
// ben specified or not.
|
||||||
|
if !bytes.Equal(shaChainPub[:], zeroPub[:]) {
|
||||||
|
s.ShaChainRootDesc.PubKey, err = btcec.ParsePubKey(
|
||||||
|
shaChainPub[:], btcec.S256(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var shaKeyFam uint32
|
||||||
|
if err := lnwire.ReadElements(r, &shaKeyFam); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ShaChainRootDesc.KeyLocator.Family = keychain.KeyFamily(shaKeyFam)
|
||||||
|
|
||||||
|
return lnwire.ReadElements(r, &s.ShaChainRootDesc.KeyLocator.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnpackFromReader is similar to Deserialize method, but it expects the passed
|
||||||
|
// io.Reader to contain an encrypt SCB. Refer to the SerializeAndEncrypt method
|
||||||
|
// for details w.r.t the encryption scheme used. If we're unable to decrypt the
|
||||||
|
// payload for whatever reason (wrong key, wrong nonce, etc), then this method
|
||||||
|
// will return an error.
|
||||||
|
func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
|
||||||
|
plaintext, err := decryptPayloadFromReader(r, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, we'll pack the bytes into a reader to we can deserialize
|
||||||
|
// the plaintext bytes of the SCB.
|
||||||
|
backupReader := bytes.NewReader(plaintext)
|
||||||
|
return s.Deserialize(backupReader)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackStaticChanBackups accepts a set of existing open channels, and a
|
||||||
|
// keychain.KeyRing, and returns a map of outpoints to the serialized+encrypted
|
||||||
|
// static channel backups. The passed keyRing should be backed by the users
|
||||||
|
// root HD seed in order to ensure full determinism.
|
||||||
|
func PackStaticChanBackups(backups []Single,
|
||||||
|
keyRing keychain.KeyRing) (map[wire.OutPoint][]byte, error) {
|
||||||
|
|
||||||
|
packedBackups := make(map[wire.OutPoint][]byte)
|
||||||
|
for _, chanBackup := range backups {
|
||||||
|
chanPoint := chanBackup.FundingOutpoint
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
err := chanBackup.PackToWriter(&b, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to pack chan backup "+
|
||||||
|
"for %v: %v", chanPoint, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packedBackups[chanPoint] = b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
return packedBackups, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PackedSingles represents a series of fully packed SCBs. This may be the
|
||||||
|
// combination of a series of individual SCBs in order to batch their
|
||||||
|
// unpacking.
|
||||||
|
type PackedSingles [][]byte
|
||||||
|
|
||||||
|
// Unpack attempts to decrypt the passed set of encrypted SCBs and deserialize
|
||||||
|
// each one into a new SCB struct. The passed keyRing should be backed by the
|
||||||
|
// same HD seed as was used to encrypt the set of backups in the first place.
|
||||||
|
// If we're unable to decrypt any of the back ups, then we'll return an error.
|
||||||
|
func (p PackedSingles) Unpack(keyRing keychain.KeyRing) ([]Single, error) {
|
||||||
|
|
||||||
|
backups := make([]Single, len(p))
|
||||||
|
for i, encryptedBackup := range p {
|
||||||
|
var backup Single
|
||||||
|
|
||||||
|
backupReader := bytes.NewReader(encryptedBackup)
|
||||||
|
err := backup.UnpackFromReader(backupReader, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
backups[i] = backup
|
||||||
|
}
|
||||||
|
|
||||||
|
return backups, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbeef): make codec package?
|
342
chanbackup/single_test.go
Normal file
342
chanbackup/single_test.go
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
package chanbackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcec"
|
||||||
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/btcsuite/btcd/wire"
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/shachain"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
chainHash = chainhash.Hash{
|
||||||
|
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
|
||||||
|
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
|
||||||
|
0x4f, 0x2f, 0x6f, 0x25, 0x18, 0xa3, 0xef, 0xb9,
|
||||||
|
0x64, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
|
||||||
|
}
|
||||||
|
|
||||||
|
op = wire.OutPoint{
|
||||||
|
Hash: chainHash,
|
||||||
|
Index: 4,
|
||||||
|
}
|
||||||
|
|
||||||
|
addr1, _ = net.ResolveTCPAddr("tcp", "10.0.0.2:9000")
|
||||||
|
addr2, _ = net.ResolveTCPAddr("tcp", "10.0.0.3:9000")
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertSingleEqual(t *testing.T, a, b Single) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if a.Version != b.Version {
|
||||||
|
t.Fatalf("versions don't match: %v vs %v", a.Version,
|
||||||
|
b.Version)
|
||||||
|
}
|
||||||
|
if a.ChainHash != b.ChainHash {
|
||||||
|
t.Fatalf("chainhash doesn't match: %v vs %v", a.ChainHash,
|
||||||
|
b.ChainHash)
|
||||||
|
}
|
||||||
|
if a.FundingOutpoint != b.FundingOutpoint {
|
||||||
|
t.Fatalf("chan point doesn't match: %v vs %v",
|
||||||
|
a.FundingOutpoint, b.FundingOutpoint)
|
||||||
|
}
|
||||||
|
if a.ShortChannelID != b.ShortChannelID {
|
||||||
|
t.Fatalf("chan id doesn't match: %v vs %v",
|
||||||
|
a.ShortChannelID, b.ShortChannelID)
|
||||||
|
}
|
||||||
|
if !a.RemoteNodePub.IsEqual(b.RemoteNodePub) {
|
||||||
|
t.Fatalf("node pubs don't match %x vs %x",
|
||||||
|
a.RemoteNodePub.SerializeCompressed(),
|
||||||
|
b.RemoteNodePub.SerializeCompressed())
|
||||||
|
}
|
||||||
|
if a.CsvDelay != b.CsvDelay {
|
||||||
|
t.Fatalf("csv delay doesn't match: %v vs %v", a.CsvDelay,
|
||||||
|
b.CsvDelay)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(a.PaymentBasePoint, b.PaymentBasePoint) {
|
||||||
|
t.Fatalf("base point doesn't match: %v vs %v",
|
||||||
|
spew.Sdump(a.PaymentBasePoint),
|
||||||
|
spew.Sdump(b.PaymentBasePoint))
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(a.ShaChainRootDesc, b.ShaChainRootDesc) {
|
||||||
|
t.Fatalf("sha chain point doesn't match: %v vs %v",
|
||||||
|
spew.Sdump(a.PaymentBasePoint),
|
||||||
|
spew.Sdump(b.PaymentBasePoint))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(a.Addresses) != len(b.Addresses) {
|
||||||
|
t.Fatalf("expected %v addrs got %v", len(a.Addresses),
|
||||||
|
len(b.Addresses))
|
||||||
|
}
|
||||||
|
for i := 0; i < len(a.Addresses); i++ {
|
||||||
|
if a.Addresses[i].String() != b.Addresses[i].String() {
|
||||||
|
t.Fatalf("addr mismatch: %v vs %v",
|
||||||
|
a.Addresses[i], b.Addresses[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) {
|
||||||
|
var testPriv [32]byte
|
||||||
|
if _, err := rand.Read(testPriv[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:])
|
||||||
|
|
||||||
|
var chanPoint wire.OutPoint
|
||||||
|
if _, err := rand.Read(chanPoint.Hash[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pub.Curve = nil
|
||||||
|
|
||||||
|
chanPoint.Index = uint32(rand.Intn(math.MaxUint16))
|
||||||
|
|
||||||
|
var shaChainRoot [32]byte
|
||||||
|
if _, err := rand.Read(shaChainRoot[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
shaChainProducer := shachain.NewRevocationProducer(shaChainRoot)
|
||||||
|
|
||||||
|
return &channeldb.OpenChannel{
|
||||||
|
ChainHash: chainHash,
|
||||||
|
FundingOutpoint: chanPoint,
|
||||||
|
ShortChannelID: lnwire.NewShortChanIDFromInt(
|
||||||
|
uint64(rand.Int63()),
|
||||||
|
),
|
||||||
|
IdentityPub: pub,
|
||||||
|
LocalChanCfg: channeldb.ChannelConfig{
|
||||||
|
ChannelConstraints: channeldb.ChannelConstraints{
|
||||||
|
CsvDelay: uint16(rand.Int63()),
|
||||||
|
},
|
||||||
|
PaymentBasePoint: keychain.KeyDescriptor{
|
||||||
|
KeyLocator: keychain.KeyLocator{
|
||||||
|
Family: keychain.KeyFamily(rand.Int63()),
|
||||||
|
Index: uint32(rand.Int63()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
RevocationProducer: shaChainProducer,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSinglePackUnpack tests that we're able to unpack a previously packed
|
||||||
|
// channel backup.
|
||||||
|
func TestSinglePackUnpack(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Given our test pub key, we'll create an open channel shell that
|
||||||
|
// contains all the information we need to create a static channel
|
||||||
|
// backup.
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to gen open channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2})
|
||||||
|
singleChanBackup.RemoteNodePub.Curve = nil
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
versionTestCases := []struct {
|
||||||
|
// version is the pack/unpack version that we should use to
|
||||||
|
// decode/encode the final SCB.
|
||||||
|
version SingleBackupVersion
|
||||||
|
|
||||||
|
// valid tests us if this test case should pass or not.
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
// The default version, should pack/unpack with no problem.
|
||||||
|
{
|
||||||
|
version: DefaultSingleVersion,
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
|
||||||
|
// A non-default version, atm this should result in a failure.
|
||||||
|
{
|
||||||
|
version: 99,
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, versionCase := range versionTestCases {
|
||||||
|
// First, we'll re-assign SCB version to what was indicated in
|
||||||
|
// the test case.
|
||||||
|
singleChanBackup.Version = versionCase.version
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
|
||||||
|
err := singleChanBackup.PackToWriter(&b, keyRing)
|
||||||
|
switch {
|
||||||
|
// If this is a valid test case, and we failed, then we'll
|
||||||
|
// return an error.
|
||||||
|
case err != nil && versionCase.valid:
|
||||||
|
t.Fatalf("#%v, unable to pack single: %v", i, err)
|
||||||
|
|
||||||
|
// If this is an invalid test case, and we passed it, then
|
||||||
|
// we'll return an error.
|
||||||
|
case err == nil && !versionCase.valid:
|
||||||
|
t.Fatalf("#%v got nil error for invalid pack: %v",
|
||||||
|
i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a valid test case, then we'll continue to ensure
|
||||||
|
// we can unpack it, and also that if we mutate the packed
|
||||||
|
// version, then we trigger an error.
|
||||||
|
if versionCase.valid {
|
||||||
|
var unpackedSingle Single
|
||||||
|
err = unpackedSingle.UnpackFromReader(&b, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("#%v unable to unpack single: %v",
|
||||||
|
i, err)
|
||||||
|
}
|
||||||
|
unpackedSingle.RemoteNodePub.Curve = nil
|
||||||
|
|
||||||
|
assertSingleEqual(t, singleChanBackup, unpackedSingle)
|
||||||
|
|
||||||
|
// If this was a valid packing attempt, then we'll test
|
||||||
|
// to ensure that if we mutate the version prepended to
|
||||||
|
// the serialization, then unpacking will fail as well.
|
||||||
|
var rawSingle bytes.Buffer
|
||||||
|
err := unpackedSingle.Serialize(&rawSingle)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to serialize single: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBytes := rawSingle.Bytes()
|
||||||
|
rawBytes[0] ^= 1
|
||||||
|
|
||||||
|
newReader := bytes.NewReader(rawBytes)
|
||||||
|
err = unpackedSingle.Deserialize(newReader)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("#%v unpack with unknown version "+
|
||||||
|
"should have failed", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPackedSinglesUnpack tests that we're able to properly unpack a series of
|
||||||
|
// packed singles.
|
||||||
|
func TestPackedSinglesUnpack(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
// To start, we'll create 10 new singles, and them assemble their
|
||||||
|
// packed forms into a slice.
|
||||||
|
numSingles := 10
|
||||||
|
packedSingles := make([][]byte, 0, numSingles)
|
||||||
|
unpackedSingles := make([]Single, 0, numSingles)
|
||||||
|
for i := 0; i < numSingles; i++ {
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to gen channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(channel, nil)
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := single.PackToWriter(&b, keyRing); err != nil {
|
||||||
|
t.Fatalf("unable to pack single: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
packedSingles = append(packedSingles, b.Bytes())
|
||||||
|
unpackedSingles = append(unpackedSingles, single)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With all singles packed, we'll create the grouped type and attempt
|
||||||
|
// to Unpack all of them in a single go.
|
||||||
|
freshSingles, err := PackedSingles(packedSingles).Unpack(keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to unpack singles: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The set of freshly unpacked singles should exactly match the initial
|
||||||
|
// set of singles that we packed before.
|
||||||
|
for i := 0; i < len(unpackedSingles); i++ {
|
||||||
|
assertSingleEqual(t, unpackedSingles[i], freshSingles[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we mutate one of the packed singles, then the entire method
|
||||||
|
// should fail.
|
||||||
|
packedSingles[0][0] ^= 1
|
||||||
|
_, err = PackedSingles(packedSingles).Unpack(keyRing)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("unpack attempt should fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSinglePackStaticChanBackups tests that we're able to batch pack a set of
|
||||||
|
// Singles, and then unpack them obtaining the same set of unpacked singles.
|
||||||
|
func TestSinglePackStaticChanBackups(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
keyRing := &mockKeyRing{}
|
||||||
|
|
||||||
|
// First, we'll create a set of random single, and along the way,
|
||||||
|
// create a map that will let us look up each single by its chan point.
|
||||||
|
numSingles := 10
|
||||||
|
singleMap := make(map[wire.OutPoint]Single, numSingles)
|
||||||
|
unpackedSingles := make([]Single, 0, numSingles)
|
||||||
|
for i := 0; i < numSingles; i++ {
|
||||||
|
channel, err := genRandomOpenChannelShell()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to gen channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
single := NewSingle(channel, nil)
|
||||||
|
|
||||||
|
singleMap[channel.FundingOutpoint] = single
|
||||||
|
unpackedSingles = append(unpackedSingles, single)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we have all of our singles are created, we'll attempt to
|
||||||
|
// pack them all in a single batch.
|
||||||
|
packedSingleMap, err := PackStaticChanBackups(unpackedSingles, keyRing)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to pack backups: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With our packed singles obtained, we'll ensure that each of them
|
||||||
|
// match their unpacked counterparts after they themselves have been
|
||||||
|
// unpacked.
|
||||||
|
for chanPoint, single := range singleMap {
|
||||||
|
packedSingles, ok := packedSingleMap[chanPoint]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("unable to find single %v", chanPoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
var freshSingle Single
|
||||||
|
err := freshSingle.UnpackFromReader(
|
||||||
|
bytes.NewReader(packedSingles), keyRing,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to unpack single: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
assertSingleEqual(t, single, freshSingle)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we attempt to pack again, but force the key ring to fail, then
|
||||||
|
// the entire method should fail.
|
||||||
|
_, err = PackStaticChanBackups(
|
||||||
|
unpackedSingles, &mockKeyRing{true},
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("pack attempt should fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(roasbsef): fuzz parsing
|
@ -83,6 +83,13 @@ const (
|
|||||||
// in order to establish a transport session with us on the Lightning
|
// in order to establish a transport session with us on the Lightning
|
||||||
// p2p level (BOLT-0008).
|
// p2p level (BOLT-0008).
|
||||||
KeyFamilyNodeKey KeyFamily = 6
|
KeyFamilyNodeKey KeyFamily = 6
|
||||||
|
|
||||||
|
// KeyFamilyStaticBackup is the family of keys that will be used to
|
||||||
|
// derive keys that we use to encrypt and decrypt our set of static
|
||||||
|
// backups. These backups may either be stored within watch towers for
|
||||||
|
// a payment, or self stored on disk in a single file containing all
|
||||||
|
// the static channel backups.
|
||||||
|
KeyFamilyStaticBackup KeyFamily = 7
|
||||||
)
|
)
|
||||||
|
|
||||||
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
|
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
|
||||||
|
Loading…
Reference in New Issue
Block a user