Merge pull request #2370 from Roasbeef/static-chan-backups-chanbackup
chanbackup: add new package implementing static channel backups
This commit is contained in:
commit
e9889cb899
99
chanbackup/backup.go
Normal file
99
chanbackup/backup.go
Normal file
@ -0,0 +1,99 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
// LiveChannelSource is an interface that allows us to query for the set of
|
||||
// live channels. A live channel is one that is open, and has not had a
|
||||
// commitment transaction broadcast.
|
||||
type LiveChannelSource interface {
|
||||
// FetchAllChannels returns all known live channels.
|
||||
FetchAllChannels() ([]*channeldb.OpenChannel, error)
|
||||
|
||||
// FetchChannel attempts to locate a live channel identified by the
|
||||
// passed chanPoint.
|
||||
FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error)
|
||||
|
||||
// AddrsForNode returns all known addresses for the target node public
|
||||
// key.
|
||||
AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error)
|
||||
}
|
||||
|
||||
// assembleChanBackup attempts to assemble a static channel backup for the
|
||||
// passed open channel. The backup includes all information required to restore
|
||||
// the channel, as well as addressing information so we can find the peer and
|
||||
// reconnect to them to initiate the protocol.
|
||||
func assembleChanBackup(chanSource LiveChannelSource,
|
||||
openChan *channeldb.OpenChannel) (*Single, error) {
|
||||
|
||||
log.Debugf("Crafting backup for ChannelPoint(%v)",
|
||||
openChan.FundingOutpoint)
|
||||
|
||||
// First, we'll query the channel source to obtain all the addresses
|
||||
// that are are associated with the peer for this channel.
|
||||
nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
single := NewSingle(openChan, nodeAddrs)
|
||||
|
||||
return &single, nil
|
||||
}
|
||||
|
||||
// FetchBackupForChan attempts to create a plaintext static channel backup for
|
||||
// the target channel identified by its channel point. If we're unable to find
|
||||
// the target channel, then an error will be returned.
|
||||
func FetchBackupForChan(chanPoint wire.OutPoint,
|
||||
chanSource LiveChannelSource) (*Single, error) {
|
||||
|
||||
// First, we'll query the channel source to see if the channel is known
|
||||
// and open within the database.
|
||||
targetChan, err := chanSource.FetchChannel(chanPoint)
|
||||
if err != nil {
|
||||
// If we can't find the channel, then we return with an error,
|
||||
// as we have nothing to backup.
|
||||
return nil, fmt.Errorf("unable to find target channel")
|
||||
}
|
||||
|
||||
// Once we have the target channel, we can assemble the backup using
|
||||
// the source to obtain any extra information that we may need.
|
||||
staticChanBackup, err := assembleChanBackup(chanSource, targetChan)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create chan backup: %v", err)
|
||||
}
|
||||
|
||||
return staticChanBackup, nil
|
||||
}
|
||||
|
||||
// FetchStaticChanBackups will return a plaintext static channel back up for
|
||||
// all known active/open channels within the passed channel source.
|
||||
func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) {
|
||||
// First, we'll query the backup source for information concerning all
|
||||
// currently open and available channels.
|
||||
openChans, err := chanSource.FetchAllChannels()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now that we have all the channels, we'll use the chanSource to
|
||||
// obtain any auxiliary information we need to craft a backup for each
|
||||
// channel.
|
||||
staticChanBackups := make([]Single, len(openChans))
|
||||
for i, openChan := range openChans {
|
||||
chanBackup, err := assembleChanBackup(chanSource, openChan)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
staticChanBackups[i] = *chanBackup
|
||||
}
|
||||
|
||||
return staticChanBackups, nil
|
||||
}
|
197
chanbackup/backup_test.go
Normal file
197
chanbackup/backup_test.go
Normal file
@ -0,0 +1,197 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
)
|
||||
|
||||
type mockChannelSource struct {
|
||||
chans map[wire.OutPoint]*channeldb.OpenChannel
|
||||
|
||||
failQuery bool
|
||||
|
||||
addrs map[[33]byte][]net.Addr
|
||||
}
|
||||
|
||||
func newMockChannelSource() *mockChannelSource {
|
||||
return &mockChannelSource{
|
||||
chans: make(map[wire.OutPoint]*channeldb.OpenChannel),
|
||||
addrs: make(map[[33]byte][]net.Addr),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error) {
|
||||
if m.failQuery {
|
||||
return nil, fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
chans := make([]*channeldb.OpenChannel, 0, len(m.chans))
|
||||
for _, channel := range m.chans {
|
||||
chans = append(chans, channel)
|
||||
}
|
||||
|
||||
return chans, nil
|
||||
}
|
||||
|
||||
func (m *mockChannelSource) FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) {
|
||||
if m.failQuery {
|
||||
return nil, fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
channel, ok := m.chans[chanPoint]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can't find chan")
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []net.Addr) {
|
||||
var nodeKey [33]byte
|
||||
copy(nodeKey[:], nodePub.SerializeCompressed())
|
||||
|
||||
m.addrs[nodeKey] = addrs
|
||||
}
|
||||
|
||||
func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
|
||||
if m.failQuery {
|
||||
return nil, fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
var nodeKey [33]byte
|
||||
copy(nodeKey[:], nodePub.SerializeCompressed())
|
||||
|
||||
addrs, ok := m.addrs[nodeKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can't find addr")
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// TestFetchBackupForChan tests that we're able to construct a single channel
|
||||
// backup for channels that are known, unknown, and also channels in which we
|
||||
// can find addresses for and otherwise.
|
||||
func TestFetchBackupForChan(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll make two channels, only one of them will have all the
|
||||
// information we need to construct set of backups for them.
|
||||
randomChan1, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate chan: %v", err)
|
||||
}
|
||||
randomChan2, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate chan: %v", err)
|
||||
}
|
||||
|
||||
chanSource := newMockChannelSource()
|
||||
chanSource.chans[randomChan1.FundingOutpoint] = randomChan1
|
||||
chanSource.chans[randomChan2.FundingOutpoint] = randomChan2
|
||||
|
||||
chanSource.addAddrsForNode(randomChan1.IdentityPub, []net.Addr{addr1})
|
||||
|
||||
testCases := []struct {
|
||||
chanPoint wire.OutPoint
|
||||
|
||||
pass bool
|
||||
}{
|
||||
// Able to find channel, and addresses, should pass.
|
||||
{
|
||||
chanPoint: randomChan1.FundingOutpoint,
|
||||
pass: true,
|
||||
},
|
||||
|
||||
// Able to find channel, not able to find addrs, should fail.
|
||||
{
|
||||
chanPoint: randomChan2.FundingOutpoint,
|
||||
pass: false,
|
||||
},
|
||||
|
||||
// Not able to find channel, should fail.
|
||||
{
|
||||
chanPoint: op,
|
||||
pass: false,
|
||||
},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
_, err := FetchBackupForChan(testCase.chanPoint, chanSource)
|
||||
switch {
|
||||
// If this is a valid test case, and we failed, then we'll
|
||||
// return an error.
|
||||
case err != nil && testCase.pass:
|
||||
t.Fatalf("#%v, unable to make chan backup: %v", i, err)
|
||||
|
||||
// If this is an invalid test case, and we passed it, then
|
||||
// we'll return an error.
|
||||
case err == nil && !testCase.pass:
|
||||
t.Fatalf("#%v got nil error for invalid req: %v",
|
||||
i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFetchStaticChanBackups tests that we're able to properly query the
|
||||
// channel source for all channels and construct a Single for each channel.
|
||||
func TestFetchStaticChanBackups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// First, we'll make the set of channels that we want to seed the
|
||||
// channel source with. Both channels will be fully populated in the
|
||||
// channel source.
|
||||
const numChans = 2
|
||||
randomChan1, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate chan: %v", err)
|
||||
}
|
||||
randomChan2, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to generate chan: %v", err)
|
||||
}
|
||||
|
||||
chanSource := newMockChannelSource()
|
||||
chanSource.chans[randomChan1.FundingOutpoint] = randomChan1
|
||||
chanSource.chans[randomChan2.FundingOutpoint] = randomChan2
|
||||
chanSource.addAddrsForNode(randomChan1.IdentityPub, []net.Addr{addr1})
|
||||
chanSource.addAddrsForNode(randomChan2.IdentityPub, []net.Addr{addr2})
|
||||
|
||||
// With the channel source populated, we'll now attempt to create a set
|
||||
// of backups for all the channels. This should succeed, as all items
|
||||
// are populated within the channel source.
|
||||
backups, err := FetchStaticChanBackups(chanSource)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create chan back ups: %v", err)
|
||||
}
|
||||
|
||||
if len(backups) != numChans {
|
||||
t.Fatalf("expected %v chans, instead got %v", numChans,
|
||||
len(backups))
|
||||
}
|
||||
|
||||
// We'll attempt to create a set up backups again, but this time the
|
||||
// second channel will have missing information, which should cause the
|
||||
// query to fail.
|
||||
var n [33]byte
|
||||
copy(n[:], randomChan2.IdentityPub.SerializeCompressed())
|
||||
delete(chanSource.addrs, n)
|
||||
|
||||
_, err = FetchStaticChanBackups(chanSource)
|
||||
if err == nil {
|
||||
t.Fatalf("query with incomplete information should fail")
|
||||
}
|
||||
|
||||
// To wrap up, we'll ensure that if we're unable to query the channel
|
||||
// source at all, then we'll fail as well.
|
||||
chanSource = newMockChannelSource()
|
||||
chanSource.failQuery = true
|
||||
_, err = FetchStaticChanBackups(chanSource)
|
||||
if err == nil {
|
||||
t.Fatalf("query should fail")
|
||||
}
|
||||
}
|
160
chanbackup/backupfile.go
Normal file
160
chanbackup/backupfile.go
Normal file
@ -0,0 +1,160 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultBackupFileName is the default name of the auto updated static
|
||||
// channel backup fie.
|
||||
DefaultBackupFileName = "channel.backup"
|
||||
|
||||
// DefaultTempBackupFileName is the default name of the temporary SCB
|
||||
// file that we'll use to atomically update the primary back up file
|
||||
// when new channel are detected.
|
||||
DefaultTempBackupFileName = "temp-dont-use.backup"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNoBackupFileExists is returned if caller attempts to call
|
||||
// UpdateAndSwap with the file name not set.
|
||||
ErrNoBackupFileExists = fmt.Errorf("back up file name not set")
|
||||
|
||||
// ErrNoTempBackupFile is returned if caller attempts to call
|
||||
// UpdateAndSwap with the temp back up file name not set.
|
||||
ErrNoTempBackupFile = fmt.Errorf("temp backup file not set")
|
||||
)
|
||||
|
||||
// MultiFile represents a file on disk that a caller can use to read the packed
|
||||
// multi backup into an unpacked one, and also atomically update the contents
|
||||
// on disk once new channels have been opened, and old ones closed. This struct
|
||||
// relies on an atomic file rename property which most widely use file systems
|
||||
// have.
|
||||
type MultiFile struct {
|
||||
// fileName is the file name of the main back up file.
|
||||
fileName string
|
||||
|
||||
// mainFile is an open handle to the main back up file.
|
||||
mainFile *os.File
|
||||
|
||||
// tempFileName is the name of the file that we'll use to stage a new
|
||||
// packed multi-chan backup, and the rename to the main back up file.
|
||||
tempFileName string
|
||||
|
||||
// tempFile is an open handle to the temp back up file.
|
||||
tempFile *os.File
|
||||
}
|
||||
|
||||
// NewMultiFile create a new multi-file instance at the target location on the
|
||||
// file system.
|
||||
func NewMultiFile(fileName string) *MultiFile {
|
||||
|
||||
// We'll our temporary backup file in the very same directory as the
|
||||
// main backup file.
|
||||
backupFileDir := filepath.Dir(fileName)
|
||||
tempFileName := filepath.Join(
|
||||
backupFileDir, DefaultTempBackupFileName,
|
||||
)
|
||||
|
||||
return &MultiFile{
|
||||
fileName: fileName,
|
||||
tempFileName: tempFileName,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateAndSwap will attempt write a new temporary backup file to disk with
|
||||
// the newBackup encoded, then atomically swap (via rename) the old file for
|
||||
// the new file by updating the name of the new file to the old.
|
||||
func (b *MultiFile) UpdateAndSwap(newBackup PackedMulti) error {
|
||||
// If the main backup file isn't set, then we can't proceed.
|
||||
if b.fileName == "" {
|
||||
return ErrNoBackupFileExists
|
||||
}
|
||||
|
||||
// If the old back up file still exists, then we'll delete it before
|
||||
// proceeding.
|
||||
if _, err := os.Stat(b.tempFileName); err == nil {
|
||||
log.Infof("Found old temp backup @ %v, removing before swap",
|
||||
b.tempFileName)
|
||||
|
||||
err = os.Remove(b.tempFileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove temp "+
|
||||
"backup file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we know the staging area is clear, we'll create the new
|
||||
// temporary back up file.
|
||||
var err error
|
||||
b.tempFile, err = os.Create(b.tempFileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// With the file created, we'll write the new packed multi backup and
|
||||
// remove the temporary file all together once this method exits.
|
||||
_, err = b.tempFile.Write([]byte(newBackup))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := b.tempFile.Sync(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(b.tempFileName)
|
||||
|
||||
log.Infof("Swapping old multi backup file from %v to %v",
|
||||
b.tempFileName, b.fileName)
|
||||
|
||||
// Finally, we'll attempt to atomically rename the temporary file to
|
||||
// the main back up file. If this succeeds, then we'll only have a
|
||||
// single file on disk once this method exits.
|
||||
return os.Rename(b.tempFileName, b.fileName)
|
||||
}
|
||||
|
||||
// ExtractMulti attempts to extract the packed multi backup we currently point
|
||||
// to into an unpacked version. This method will fail if no backup file
|
||||
// currently exists as the specified location.
|
||||
func (b *MultiFile) ExtractMulti(keyChain keychain.KeyRing) (*Multi, error) {
|
||||
var err error
|
||||
|
||||
// If the backup file isn't already set, then we'll attempt to open it
|
||||
// anew.
|
||||
if b.mainFile == nil {
|
||||
// We'll return an error if the main file isn't currently set.
|
||||
if b.fileName == "" {
|
||||
return nil, ErrNoBackupFileExists
|
||||
}
|
||||
|
||||
// Otherwise, we'll open the file to prep for reading the
|
||||
// contents.
|
||||
b.mainFile, err = os.Open(b.fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Before we start to read the file, we'll ensure that the next read
|
||||
// call will start from the front of the file.
|
||||
_, err = b.mainFile.Seek(0, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// With our seek successful, we'll now attempt to read the contents of
|
||||
// the entire file in one swoop.
|
||||
multiBytes, err := ioutil.ReadAll(b.mainFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Finally, we'll attempt to unpack the file and return the unpack
|
||||
// version to the caller.
|
||||
packedMulti := PackedMulti(multiBytes)
|
||||
return packedMulti.Unpack(keyChain)
|
||||
}
|
289
chanbackup/backupfile_test.go
Normal file
289
chanbackup/backupfile_test.go
Normal file
@ -0,0 +1,289 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeFakePackedMulti() (PackedMulti, error) {
|
||||
newPackedMulti := make([]byte, 50)
|
||||
if _, err := rand.Read(newPackedMulti[:]); err != nil {
|
||||
return nil, fmt.Errorf("unable to make test backup: %v", err)
|
||||
}
|
||||
|
||||
return PackedMulti(newPackedMulti), nil
|
||||
}
|
||||
|
||||
func assertBackupMatches(t *testing.T, filePath string,
|
||||
currentBackup PackedMulti) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
packedBackup, err := ioutil.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to test file: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(packedBackup, currentBackup) {
|
||||
t.Fatalf("backups don't match after first swap: "+
|
||||
"expected %x got %x", packedBackup[:],
|
||||
currentBackup)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFileDeleted(t *testing.T, filePath string) {
|
||||
t.Helper()
|
||||
|
||||
_, err := os.Stat(filePath)
|
||||
if err == nil {
|
||||
t.Fatalf("file %v still exists: ", filePath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateAndSwap test that we're able to properly swap out old backups on
|
||||
// disk with new ones. Additionally, after a swap operation succeeds, then each
|
||||
// time we should only have the main backup file on disk, as the temporary file
|
||||
// has been removed.
|
||||
func TestUpdateAndSwap(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tempTestDir, err := ioutil.TempDir("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make temp dir: %v", err)
|
||||
}
|
||||
defer os.Remove(tempTestDir)
|
||||
|
||||
testCases := []struct {
|
||||
fileName string
|
||||
tempFileName string
|
||||
|
||||
oldTempExists bool
|
||||
|
||||
valid bool
|
||||
}{
|
||||
// Main file name is blank, should fail.
|
||||
{
|
||||
fileName: "",
|
||||
valid: false,
|
||||
},
|
||||
|
||||
// Old temporary file still exists, should be removed. Only one
|
||||
// file should remain.
|
||||
{
|
||||
fileName: filepath.Join(
|
||||
tempTestDir, DefaultBackupFileName,
|
||||
),
|
||||
tempFileName: filepath.Join(
|
||||
tempTestDir, DefaultTempBackupFileName,
|
||||
),
|
||||
oldTempExists: true,
|
||||
valid: true,
|
||||
},
|
||||
|
||||
// Old temp doesn't exist, should swap out file, only a single
|
||||
// file remains.
|
||||
{
|
||||
fileName: filepath.Join(
|
||||
tempTestDir, DefaultBackupFileName,
|
||||
),
|
||||
tempFileName: filepath.Join(
|
||||
tempTestDir, DefaultTempBackupFileName,
|
||||
),
|
||||
valid: true,
|
||||
},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
// Ensure that all created files are removed at the end of the
|
||||
// test case.
|
||||
defer os.Remove(testCase.fileName)
|
||||
defer os.Remove(testCase.tempFileName)
|
||||
|
||||
backupFile := NewMultiFile(testCase.fileName)
|
||||
|
||||
// To start with, we'll make a random byte slice that'll pose
|
||||
// as our packed multi backup.
|
||||
newPackedMulti, err := makeFakePackedMulti()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test backup: %v", err)
|
||||
}
|
||||
|
||||
// If the old temporary file is meant to exist, then we'll
|
||||
// create it now as an empty file.
|
||||
if testCase.oldTempExists {
|
||||
_, err := os.Create(testCase.tempFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp file: %v", err)
|
||||
}
|
||||
|
||||
// TODO(roasbeef): mock out fs calls?
|
||||
}
|
||||
|
||||
// With our backup created, we'll now attempt to swap out this
|
||||
// backup, for the old one.
|
||||
err = backupFile.UpdateAndSwap(PackedMulti(newPackedMulti))
|
||||
switch {
|
||||
// If this is a valid test case, and we failed, then we'll
|
||||
// return an error.
|
||||
case err != nil && testCase.valid:
|
||||
t.Fatalf("#%v, unable to swap file: %v", i, err)
|
||||
|
||||
// If this is an invalid test case, and we passed it, then
|
||||
// we'll return an error.
|
||||
case err == nil && !testCase.valid:
|
||||
t.Fatalf("#%v file swap should have failed: %v", i, err)
|
||||
}
|
||||
|
||||
if !testCase.valid {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we read out the file on disk, then it should match
|
||||
// exactly what we wrote. The temp backup file should also be
|
||||
// gone.
|
||||
assertBackupMatches(t, testCase.fileName, newPackedMulti)
|
||||
assertFileDeleted(t, testCase.tempFileName)
|
||||
|
||||
// Now that we know this is a valid test case, we'll make a new
|
||||
// packed multi to swap out this current one.
|
||||
newPackedMulti2, err := makeFakePackedMulti()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test backup: %v", err)
|
||||
}
|
||||
|
||||
// We'll then attempt to swap the old version for this new one.
|
||||
err = backupFile.UpdateAndSwap(PackedMulti(newPackedMulti2))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to swap file: %v", err)
|
||||
}
|
||||
|
||||
// Once again, the file written on disk should have been
|
||||
// properly swapped out with the new instance.
|
||||
assertBackupMatches(t, testCase.fileName, newPackedMulti2)
|
||||
|
||||
// Additionally, we shouldn't be able to find the temp backup
|
||||
// file on disk, as it should be deleted each time.
|
||||
assertFileDeleted(t, testCase.tempFileName)
|
||||
}
|
||||
}
|
||||
|
||||
func assertMultiEqual(t *testing.T, a, b *Multi) {
|
||||
|
||||
if len(a.StaticBackups) != len(b.StaticBackups) {
|
||||
t.Fatalf("expected %v backups, got %v", len(a.StaticBackups),
|
||||
len(b.StaticBackups))
|
||||
}
|
||||
|
||||
for i := 0; i < len(a.StaticBackups); i++ {
|
||||
assertSingleEqual(t, a.StaticBackups[i], b.StaticBackups[i])
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractMulti tests that given a valid packed multi file on disk, we're
|
||||
// able to read it multiple times repeatedly.
|
||||
func TestExtractMulti(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
// First, as prep, we'll create a single chan backup, then pack that
|
||||
// fully into a multi backup.
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen chan: %v", err)
|
||||
}
|
||||
|
||||
singleBackup := NewSingle(channel, nil)
|
||||
|
||||
var b bytes.Buffer
|
||||
unpackedMulti := Multi{
|
||||
StaticBackups: []Single{singleBackup},
|
||||
}
|
||||
err = unpackedMulti.PackToWriter(&b, keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to pack to writer: %v", err)
|
||||
}
|
||||
|
||||
packedMulti := PackedMulti(b.Bytes())
|
||||
|
||||
// Finally, we'll make a new temporary file, then write out the packed
|
||||
// multi directly to to it.
|
||||
tempFile, err := ioutil.TempFile("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create temp file: %v", err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
_, err = tempFile.Write(packedMulti)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to write temp file: %v", err)
|
||||
}
|
||||
if err := tempFile.Sync(); err != nil {
|
||||
t.Fatalf("unable to sync temp file: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
fileName string
|
||||
pass bool
|
||||
}{
|
||||
// Main file not read, file name not present.
|
||||
{
|
||||
fileName: "",
|
||||
pass: false,
|
||||
},
|
||||
|
||||
// Main file not read, file name is there, but file doesn't
|
||||
// exist.
|
||||
{
|
||||
fileName: "kek",
|
||||
pass: false,
|
||||
},
|
||||
|
||||
// Main file not read, should be able to read multiple times.
|
||||
{
|
||||
fileName: tempFile.Name(),
|
||||
pass: true,
|
||||
},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
// First, we'll make our backup file with the specified name.
|
||||
backupFile := NewMultiFile(testCase.fileName)
|
||||
|
||||
// With our file made, we'll now attempt to read out the
|
||||
// multi-file.
|
||||
freshUnpackedMulti, err := backupFile.ExtractMulti(keyRing)
|
||||
switch {
|
||||
// If this is a valid test case, and we failed, then we'll
|
||||
// return an error.
|
||||
case err != nil && testCase.pass:
|
||||
t.Fatalf("#%v, unable to extract file: %v", i, err)
|
||||
|
||||
// If this is an invalid test case, and we passed it, then
|
||||
// we'll return an error.
|
||||
case err == nil && !testCase.pass:
|
||||
t.Fatalf("#%v file extraction should have "+
|
||||
"failed: %v", i, err)
|
||||
}
|
||||
|
||||
if !testCase.pass {
|
||||
continue
|
||||
}
|
||||
|
||||
// We'll now ensure that the unpacked multi we read is
|
||||
// identical to the one we wrote out above.
|
||||
assertMultiEqual(t, &unpackedMulti, freshUnpackedMulti)
|
||||
|
||||
// We should also be able to read the file again, as we have an
|
||||
// existing handle to it.
|
||||
freshUnpackedMulti, err = backupFile.ExtractMulti(keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unpack multi: %v", err)
|
||||
}
|
||||
|
||||
assertMultiEqual(t, &unpackedMulti, freshUnpackedMulti)
|
||||
}
|
||||
}
|
140
chanbackup/crypto.go
Normal file
140
chanbackup/crypto.go
Normal file
@ -0,0 +1,140 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// TODO(roasbeef): interface in front of?
|
||||
|
||||
// baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base
|
||||
// encryption key used for encrypting all static channel backups. We use this
|
||||
// to then derive the actual key that we'll use for encryption. We do this
|
||||
// rather than using the raw key, as we assume that we can't obtain the raw
|
||||
// keys, and we don't want to require that the HSM know our target cipher for
|
||||
// encryption.
|
||||
//
|
||||
// TODO(roasbeef): possibly unique encrypt?
|
||||
var baseEncryptionKeyLoc = keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyStaticBackup,
|
||||
Index: 0,
|
||||
}
|
||||
|
||||
// genEncryptionKey derives the key that we'll use to encrypt all of our static
|
||||
// channel backups. The key itself, is the sha2 of a base key that we get from
|
||||
// the keyring. We derive the key this way as we don't force the HSM (or any
|
||||
// future abstractions) to be able to derive and know of the cipher that we'll
|
||||
// use within our protocol.
|
||||
func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) {
|
||||
// key = SHA256(baseKey)
|
||||
baseKey, err := keyRing.DeriveKey(
|
||||
baseEncryptionKeyLoc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptionKey := sha256.Sum256(
|
||||
baseKey.PubKey.SerializeCompressed(),
|
||||
)
|
||||
|
||||
// TODO(roasbeef): throw back in ECDH?
|
||||
|
||||
return encryptionKey[:], nil
|
||||
}
|
||||
|
||||
// encryptPayloadToWriter attempts to write the set of bytes contained within
|
||||
// the passed byes.Buffer into the passed io.Writer in an encrypted form. We
|
||||
// use a 24-byte chachapoly AEAD instance with a randomized nonce that's
|
||||
// pre-pended to the final payload and used as associated data in the AEAD. We
|
||||
// use the passed keyRing to generate the encryption key, see genEncryptionKey
|
||||
// for further details.
|
||||
func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer,
|
||||
keyRing keychain.KeyRing) error {
|
||||
|
||||
// First, we'll derive the key that we'll use to encrypt the payload
|
||||
// for safe storage without giving away the details of any of our
|
||||
// channels. The final operation is:
|
||||
//
|
||||
// key = SHA256(baseKey)
|
||||
encryptionKey, err := genEncryptionKey(keyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Before encryption, we'll initialize our cipher with the target
|
||||
// encryption key, and also read out our random 24-byte nonce we use
|
||||
// for encryption. Note that we use NewX, not New, as the latter
|
||||
// version requires a 12-byte nonce, not a 24-byte nonce.
|
||||
cipher, err := chacha20poly1305.NewX(encryptionKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var nonce [chacha20poly1305.NonceSizeX]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, we encrypted the final payload, and write out our
|
||||
// ciphertext with nonce pre-pended.
|
||||
ciphertext := cipher.Seal(nil, nonce[:], payload.Bytes(), nonce[:])
|
||||
|
||||
if _, err := w.Write(nonce[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(ciphertext); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptPayloadFromReader attempts to decrypt the encrypted bytes within the
|
||||
// passed io.Reader instance using the key derived from the passed keyRing. For
|
||||
// further details regarding the key derivation protocol, see the
|
||||
// genEncryptionKey method.
|
||||
func decryptPayloadFromReader(payload io.Reader,
|
||||
keyRing keychain.KeyRing) ([]byte, error) {
|
||||
|
||||
// First, we'll re-generate the encryption key that we use for all the
|
||||
// SCBs.
|
||||
encryptionKey, err := genEncryptionKey(keyRing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Next, we'll read out the entire blob as we need to isolate the nonce
|
||||
// from the rest of the ciphertext.
|
||||
packedBackup, err := ioutil.ReadAll(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(packedBackup) < chacha20poly1305.NonceSizeX {
|
||||
return nil, fmt.Errorf("payload size too small, must be at "+
|
||||
"least %v bytes", chacha20poly1305.NonceSizeX)
|
||||
}
|
||||
|
||||
nonce := packedBackup[:chacha20poly1305.NonceSizeX]
|
||||
ciphertext := packedBackup[chacha20poly1305.NonceSizeX:]
|
||||
|
||||
// Now that we have the cipher text and the nonce separated, we can go
|
||||
// ahead and decrypt the final blob so we can properly serialized the
|
||||
// SCB.
|
||||
cipher, err := chacha20poly1305.NewX(encryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext, err := cipher.Open(nil, nonce, ciphertext, nonce)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
156
chanbackup/crypto_test.go
Normal file
156
chanbackup/crypto_test.go
Normal file
@ -0,0 +1,156 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
var (
|
||||
testWalletPrivKey = []byte{
|
||||
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
|
||||
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
|
||||
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
|
||||
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
|
||||
}
|
||||
)
|
||||
|
||||
type mockKeyRing struct {
|
||||
fail bool
|
||||
}
|
||||
|
||||
func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
|
||||
return keychain.KeyDescriptor{}, nil
|
||||
}
|
||||
func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
|
||||
if m.fail {
|
||||
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey)
|
||||
return keychain.KeyDescriptor{
|
||||
PubKey: pub,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestEncryptDecryptPayload tests that given a static key, we're able to
|
||||
// properly decrypt and encrypted payload. We also test that we'll reject a
|
||||
// ciphertext that has been modified.
|
||||
func TestEncryptDecryptPayload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payloadCases := []struct {
|
||||
// plaintext is the string that we'll be encrypting.
|
||||
plaintext []byte
|
||||
|
||||
// mutator allows a test case to modify the ciphertext before
|
||||
// we attempt to decrypt it.
|
||||
mutator func(*[]byte)
|
||||
|
||||
// valid indicates if this test should pass or fail.
|
||||
valid bool
|
||||
}{
|
||||
// Proper payload, should decrypt.
|
||||
{
|
||||
plaintext: []byte("payload test plain text"),
|
||||
mutator: nil,
|
||||
valid: true,
|
||||
},
|
||||
|
||||
// Mutator modifies cipher text, shouldn't decrypt.
|
||||
{
|
||||
plaintext: []byte("payload test plain text"),
|
||||
mutator: func(p *[]byte) {
|
||||
// Flip a byte in the payload to render it invalid.
|
||||
(*p)[0] ^= 1
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
|
||||
// Cipher text is too small, shouldn't decrypt.
|
||||
{
|
||||
plaintext: []byte("payload test plain text"),
|
||||
mutator: func(p *[]byte) {
|
||||
// Modify the cipher text to be zero length.
|
||||
*p = []byte{}
|
||||
},
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
for i, payloadCase := range payloadCases {
|
||||
var cipherBuffer bytes.Buffer
|
||||
|
||||
// First, we'll encrypt the passed payload with our scheme.
|
||||
payloadReader := bytes.NewBuffer(payloadCase.plaintext)
|
||||
err := encryptPayloadToWriter(
|
||||
*payloadReader, &cipherBuffer, keyRing,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable encrypt paylaod: %v", err)
|
||||
}
|
||||
|
||||
// If we have a mutator, then we'll wrong the mutator over the
|
||||
// cipher text, then reset the main buffer and re-write the new
|
||||
// cipher text.
|
||||
if payloadCase.mutator != nil {
|
||||
cipherText := cipherBuffer.Bytes()
|
||||
|
||||
payloadCase.mutator(&cipherText)
|
||||
|
||||
cipherBuffer.Reset()
|
||||
cipherBuffer.Write(cipherText)
|
||||
}
|
||||
|
||||
plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing)
|
||||
|
||||
switch {
|
||||
// If this was meant to be a valid decryption, but we failed,
|
||||
// then we'll return an error.
|
||||
case err != nil && payloadCase.valid:
|
||||
t.Fatalf("unable to decrypt valid payload case %v", i)
|
||||
|
||||
// If this was meant to be an invalid decryption, and we didn't
|
||||
// fail, then we'll return an error.
|
||||
case err == nil && !payloadCase.valid:
|
||||
t.Fatalf("payload was invalid yet was able to decrypt")
|
||||
}
|
||||
|
||||
// Only if this case was mean to be valid will we ensure the
|
||||
// resulting decrypted plaintext matches the original input.
|
||||
if payloadCase.valid &&
|
||||
!bytes.Equal(plaintext, payloadCase.plaintext) {
|
||||
t.Fatalf("#%v: expected %v, got %v: ", i,
|
||||
payloadCase.plaintext, plaintext)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidKeyEncryption tests that encryption fails if we're unable to
|
||||
// obtain a valid key.
|
||||
func TestInvalidKeyEncryption(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b bytes.Buffer
|
||||
err := encryptPayloadToWriter(b, &b, &mockKeyRing{true})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error due to fail key gen")
|
||||
}
|
||||
}
|
||||
|
||||
// TestInvalidKeyDecrytion tests that decryption fails if we're unable to
|
||||
// obtain a valid key.
|
||||
func TestInvalidKeyDecrytion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var b bytes.Buffer
|
||||
_, err := decryptPayloadFromReader(&b, &mockKeyRing{true})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error due to fail key gen")
|
||||
}
|
||||
}
|
45
chanbackup/log.go
Normal file
45
chanbackup/log.go
Normal file
@ -0,0 +1,45 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
"github.com/lightningnetwork/lnd/build"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
// means the package will not perform any logging by default until the caller
|
||||
// requests it.
|
||||
var log btclog.Logger
|
||||
|
||||
// The default amount of logging is none.
|
||||
func init() {
|
||||
UseLogger(build.NewSubLogger("CHBU", nil))
|
||||
}
|
||||
|
||||
// DisableLog disables all library log output. Logging output is disabled
|
||||
// by default until UseLogger is called.
|
||||
func DisableLog() {
|
||||
UseLogger(btclog.Disabled)
|
||||
}
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
// This should be used in preference to SetLogWriter if the caller is also
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
||||
|
||||
// logClosure is used to provide a closure over expensive logging operations so
|
||||
// don't have to be performed when the logging level doesn't warrant it.
|
||||
type logClosure func() string
|
||||
|
||||
// String invokes the underlying function and returns the result.
|
||||
func (c logClosure) String() string {
|
||||
return c()
|
||||
}
|
||||
|
||||
// newLogClosure returns a new closure over a function that returns a string
|
||||
// which itself provides a Stringer interface so that it can be used with the
|
||||
// logging system.
|
||||
func newLogClosure(c func() string) logClosure {
|
||||
return logClosure(c)
|
||||
}
|
176
chanbackup/multi.go
Normal file
176
chanbackup/multi.go
Normal file
@ -0,0 +1,176 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// MultiBackupVersion denotes the version of the multi channel static channel
|
||||
// backup. Based on this version, we know how to encode/decode packed/unpacked
|
||||
// versions of multi backups.
|
||||
type MultiBackupVersion byte
|
||||
|
||||
const (
|
||||
// DefaultMultiVersion is the default version of the multi channel
|
||||
// backup. The serialized format for this version is simply: version ||
|
||||
// numBackups || SCBs...
|
||||
DefaultMultiVersion = 0
|
||||
)
|
||||
|
||||
// Multi is a form of static channel backup that is amenable to being
|
||||
// serialized in a single file. Rather than a series of ciphertexts, a
|
||||
// multi-chan backup is a single ciphertext of all static channel backups
|
||||
// concatenated. This form factor gives users a single blob that they can use
|
||||
// to safely copy/obtain at anytime to backup their channels.
|
||||
type Multi struct {
|
||||
// Version is the version that should be observed when attempting to
|
||||
// pack the multi backup.
|
||||
Version MultiBackupVersion
|
||||
|
||||
// StaticBackups is the set of single channel backups that this multi
|
||||
// backup is comprised of.
|
||||
StaticBackups []Single
|
||||
}
|
||||
|
||||
// PackToWriter packs (encrypts+serializes) the target set of static channel
|
||||
// backups into a single AEAD ciphertext into the passed io.Writer. This is the
|
||||
// opposite of UnpackFromReader. The plaintext form of a multi-chan backup is
|
||||
// the following: a 4 byte integer denoting the number of serialized static
|
||||
// channel backups serialized, a series of serialized static channel backups
|
||||
// concatenated. To pack this payload, we then apply our chacha20 AEAD to the
|
||||
// entire payload, using the 24-byte nonce as associated data.
|
||||
func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
|
||||
// The only version that we know how to pack atm is version 0. Attempts
|
||||
// to pack any other version will result in an error.
|
||||
switch m.Version {
|
||||
case DefaultMultiVersion:
|
||||
break
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unable to pack unknown multi-version "+
|
||||
"of %v", m.Version)
|
||||
}
|
||||
|
||||
var multiBackupBuffer bytes.Buffer
|
||||
|
||||
// First, we'll write out the version of this multi channel baackup.
|
||||
err := lnwire.WriteElements(&multiBackupBuffer, byte(m.Version))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now that we've written out the version of this multi-pack format,
|
||||
// we'll now write the total number of backups to expect after this
|
||||
// point.
|
||||
numBackups := uint32(len(m.StaticBackups))
|
||||
err = lnwire.WriteElements(&multiBackupBuffer, numBackups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Next, we'll serialize the raw plaintext version of each of the
|
||||
// backup into the intermediate buffer.
|
||||
for _, chanBackup := range m.StaticBackups {
|
||||
err := chanBackup.Serialize(&multiBackupBuffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to serialize backup "+
|
||||
"for %v: %v", chanBackup.FundingOutpoint, err)
|
||||
}
|
||||
}
|
||||
|
||||
// With the plaintext multi backup assembled, we'll now encrypt it
|
||||
// directly to the passed writer.
|
||||
return encryptPayloadToWriter(multiBackupBuffer, w, keyRing)
|
||||
}
|
||||
|
||||
// UnpackFromReader attempts to unpack (decrypt+deserialize) a packed
|
||||
// multi-chan backup form the passed io.Reader. If we're unable to decrypt the
|
||||
// any portion of the multi-chan backup, an error will be returned.
|
||||
func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
|
||||
// We'll attempt to read the entire packed backup, and also decrypt it
|
||||
// using the passed key ring which is expected to be able to derive the
|
||||
// encryption keys.
|
||||
plaintextBackup, err := decryptPayloadFromReader(r, keyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
backupReader := bytes.NewReader(plaintextBackup)
|
||||
|
||||
// Now that we've decrypted the payload successfully, we can parse out
|
||||
// each of the individual static channel backups.
|
||||
|
||||
// First, we'll need to read the version of this multi-back up so we
|
||||
// can know how to unpack each of the individual SCB's.
|
||||
var multiVersion byte
|
||||
err = lnwire.ReadElements(backupReader, &multiVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = MultiBackupVersion(multiVersion)
|
||||
switch m.Version {
|
||||
|
||||
// The default version is simply a set of serialized SCB's with the
|
||||
// number of total SCB's prepended to the front of the byte slice.
|
||||
case DefaultMultiVersion:
|
||||
// First, we'll need to read out the total number of backups
|
||||
// that've been serialized into this multi-chan backup. Each
|
||||
// backup is the same size, so we can continue until we've
|
||||
// parsed out everything.
|
||||
var numBackups uint32
|
||||
err = lnwire.ReadElements(backupReader, &numBackups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We'll continue to parse out each backup until we've read all
|
||||
// that was indicated from the length prefix.
|
||||
for ; numBackups != 0; numBackups-- {
|
||||
// Attempt to parse out the net static channel backup,
|
||||
// if it's been malformed, then we'll return with an
|
||||
// error
|
||||
var chanBackup Single
|
||||
err := chanBackup.Deserialize(backupReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Collect the next valid chan backup into the main
|
||||
// multi backup slice.
|
||||
m.StaticBackups = append(m.StaticBackups, chanBackup)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unable to unpack unknown multi-version "+
|
||||
"of %v", multiVersion)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(roasbeef): new key ring interface?
|
||||
// * just returns key given params?
|
||||
|
||||
// PackedMulti represents a raw fully packed (serialized+encrypted)
|
||||
// multi-channel static channel backup.
|
||||
type PackedMulti []byte
|
||||
|
||||
// Unpack attempts to unpack (decrypt+desrialize) the target packed
|
||||
// multi-channel back up. If we're unable to fully unpack this back, then an
|
||||
// error will be returned.
|
||||
func (p *PackedMulti) Unpack(keyRing keychain.KeyRing) (*Multi, error) {
|
||||
var m Multi
|
||||
|
||||
packedReader := bytes.NewReader(*p)
|
||||
if err := m.UnpackFromReader(packedReader, keyRing); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
// TODO(roasbsef): fuzz parsing
|
159
chanbackup/multi_test.go
Normal file
159
chanbackup/multi_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMultiPackUnpack...
|
||||
func TestMultiPackUnpack(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var multi Multi
|
||||
numSingles := 10
|
||||
originalSingles := make([]Single, 0, numSingles)
|
||||
for i := 0; i < numSingles; i++ {
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen channel: %v", err)
|
||||
}
|
||||
|
||||
single := NewSingle(channel, []net.Addr{addr1, addr2})
|
||||
|
||||
originalSingles = append(originalSingles, single)
|
||||
multi.StaticBackups = append(multi.StaticBackups, single)
|
||||
}
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
versionTestCases := []struct {
|
||||
// version is the pack/unpack version that we should use to
|
||||
// decode/encode the final SCB.
|
||||
version MultiBackupVersion
|
||||
|
||||
// valid tests us if this test case should pass or not.
|
||||
valid bool
|
||||
}{
|
||||
// The default version, should pack/unpack with no problem.
|
||||
{
|
||||
version: DefaultSingleVersion,
|
||||
valid: true,
|
||||
},
|
||||
|
||||
// A non-default version, atm this should result in a failure.
|
||||
{
|
||||
version: 99,
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
for i, versionCase := range versionTestCases {
|
||||
multi.Version = versionCase.version
|
||||
|
||||
var b bytes.Buffer
|
||||
err := multi.PackToWriter(&b, keyRing)
|
||||
switch {
|
||||
// If this is a valid test case, and we failed, then we'll
|
||||
// return an error.
|
||||
case err != nil && versionCase.valid:
|
||||
t.Fatalf("#%v, unable to pack multi: %v", i, err)
|
||||
|
||||
// If this is an invalid test case, and we passed it, then
|
||||
// we'll return an error.
|
||||
case err == nil && !versionCase.valid:
|
||||
t.Fatalf("#%v got nil error for invalid pack: %v",
|
||||
i, err)
|
||||
}
|
||||
|
||||
// If this is a valid test case, then we'll continue to ensure
|
||||
// we can unpack it, and also that if we mutate the packed
|
||||
// version, then we trigger an error.
|
||||
if versionCase.valid {
|
||||
var unpackedMulti Multi
|
||||
err = unpackedMulti.UnpackFromReader(&b, keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("#%v unable to unpack multi: %v",
|
||||
i, err)
|
||||
}
|
||||
|
||||
// First, we'll ensure that the unpacked version of the
|
||||
// packed multi is the same as the original set.
|
||||
if len(originalSingles) !=
|
||||
len(unpackedMulti.StaticBackups) {
|
||||
t.Fatalf("expected %v singles, got %v",
|
||||
len(originalSingles),
|
||||
len(unpackedMulti.StaticBackups))
|
||||
}
|
||||
for i := 0; i < numSingles; i++ {
|
||||
assertSingleEqual(
|
||||
t, originalSingles[i],
|
||||
unpackedMulti.StaticBackups[i],
|
||||
)
|
||||
}
|
||||
|
||||
// Next, we'll make a fake packed multi, it'll have an
|
||||
// unknown version relative to what's implemented atm.
|
||||
var fakePackedMulti bytes.Buffer
|
||||
fakeRawMulti := bytes.NewBuffer(
|
||||
bytes.Repeat([]byte{99}, 20),
|
||||
)
|
||||
err := encryptPayloadToWriter(
|
||||
*fakeRawMulti, &fakePackedMulti, keyRing,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to pack fake multi; %v", err)
|
||||
}
|
||||
|
||||
// We should reject this fake multi as it contains an
|
||||
// unknown version.
|
||||
err = unpackedMulti.UnpackFromReader(
|
||||
&fakePackedMulti, keyRing,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("#%v unpack with unknown version "+
|
||||
"should have failed", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestPackedMultiUnpack tests that we're able to properly unpack a typed
|
||||
// packed multi.
|
||||
func TestPackedMultiUnpack(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
// First, we'll make a new unpacked multi with a random channel.
|
||||
testChannel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen random channel: %v", err)
|
||||
}
|
||||
var multi Multi
|
||||
multi.StaticBackups = append(
|
||||
multi.StaticBackups, NewSingle(testChannel, nil),
|
||||
)
|
||||
|
||||
// Now that we have our multi, we'll pack it into a new buffer.
|
||||
var b bytes.Buffer
|
||||
if err := multi.PackToWriter(&b, keyRing); err != nil {
|
||||
t.Fatalf("unable to pack multi: %v", err)
|
||||
}
|
||||
|
||||
// We should be able to properly unpack this typed packed multi.
|
||||
packedMulti := PackedMulti(b.Bytes())
|
||||
unpackedMulti, err := packedMulti.Unpack(keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unpack multi: %v", err)
|
||||
}
|
||||
|
||||
// Finally, the versions should match, and the unpacked singles also
|
||||
// identical.
|
||||
if multi.Version != unpackedMulti.Version {
|
||||
t.Fatalf("version mismatch: expected %v got %v",
|
||||
multi.Version, unpackedMulti.Version)
|
||||
}
|
||||
assertSingleEqual(
|
||||
t, multi.StaticBackups[0], unpackedMulti.StaticBackups[0],
|
||||
)
|
||||
}
|
247
chanbackup/pubsub.go
Normal file
247
chanbackup/pubsub.go
Normal file
@ -0,0 +1,247 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
// Swapper is an interface that allows the chanbackup.SubSwapper to update the
|
||||
// main multi backup location once it learns of new channels or that prior
|
||||
// channels have been closed.
|
||||
type Swapper interface {
|
||||
// UpdateAndSwap attempts to atomically update the main multi back up
|
||||
// file location with the new fully packed multi-channel backup.
|
||||
UpdateAndSwap(newBackup PackedMulti) error
|
||||
}
|
||||
|
||||
// ChannelWithAddrs bundles an open channel along with all the addresses for
|
||||
// the channel peer.
|
||||
//
|
||||
// TODO(roasbeef): use channel shell instead?
|
||||
type ChannelWithAddrs struct {
|
||||
*channeldb.OpenChannel
|
||||
|
||||
// Addrs is the set of addresses that we can use to reach the target
|
||||
// peer.
|
||||
Addrs []net.Addr
|
||||
}
|
||||
|
||||
// ChannelEvent packages a new update of new channels since subscription, and
|
||||
// channels that have been opened since prior channel event.
|
||||
type ChannelEvent struct {
|
||||
// ClosedChans are the set of channels that have been closed since the
|
||||
// last event.
|
||||
ClosedChans []wire.OutPoint
|
||||
|
||||
// NewChans is the set of channels that have been opened since the last
|
||||
// event.
|
||||
NewChans []ChannelWithAddrs
|
||||
}
|
||||
|
||||
// ChannelSubscription represents an intent to be notified of any updates to
|
||||
// the primary channel state.
|
||||
type ChannelSubscription struct {
|
||||
// ChanUpdates is a read-only channel that will be sent upon once the
|
||||
// primary channel state is updated.
|
||||
ChanUpdates <-chan ChannelEvent
|
||||
|
||||
// Cancel is a closure that allows the caller to cancel their
|
||||
// subscription and free up any resources allocated.
|
||||
Cancel func()
|
||||
}
|
||||
|
||||
// ChannelNotifier represents a system that allows the chanbackup.SubSwapper to
|
||||
// be notified of any changes to the primary channel state.
|
||||
type ChannelNotifier interface {
|
||||
// SubscribeChans requests a new channel subscription relative to the
|
||||
// initial set of known channels. We use the knownChans as a
|
||||
// synchronization point to ensure that the chanbackup.SubSwapper does
|
||||
// not miss any channel open or close events in the period between when
|
||||
// it's created, and when it requests the channel subscription.
|
||||
SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error)
|
||||
}
|
||||
|
||||
// SubSwapper subscribes to new updates to the open channel state, and then
|
||||
// swaps out the on-disk channel backup state in response. This sub-system
|
||||
// that will ensure that the multi chan backup file on disk will always be
|
||||
// updated with the latest channel back up state. We'll receive new
|
||||
// opened/closed channels from the ChannelNotifier, then use the Swapper to
|
||||
// update the file state on disk with the new set of open channels. This can
|
||||
// be used to implement a system that always keeps the multi-chan backup file
|
||||
// on disk in a consistent state for safety purposes.
|
||||
//
|
||||
// TODO(roasbeef): better name lol
|
||||
type SubSwapper struct {
|
||||
started uint32
|
||||
stopped uint32
|
||||
|
||||
// backupState are the set of SCBs for all open channels we know of.
|
||||
backupState map[wire.OutPoint]Single
|
||||
|
||||
// chanEvents is an active subscription to receive new channel state
|
||||
// over.
|
||||
chanEvents *ChannelSubscription
|
||||
|
||||
// keyRing is the main key ring that will allow us to pack the new
|
||||
// multi backup.
|
||||
keyRing keychain.KeyRing
|
||||
|
||||
Swapper
|
||||
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewSubSwapper creates a new instance of the SubSwapper given the starting
|
||||
// set of channels, and the required interfaces to be notified of new channel
|
||||
// updates, pack a multi backup, and swap the current best backup from its
|
||||
// storage location.
|
||||
func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier,
|
||||
keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) {
|
||||
|
||||
// First, we'll subscribe to the latest set of channel updates given
|
||||
// the set of channels we already know of.
|
||||
knownChans := make(map[wire.OutPoint]struct{})
|
||||
for _, chanBackup := range startingChans {
|
||||
knownChans[chanBackup.FundingOutpoint] = struct{}{}
|
||||
}
|
||||
chanEvents, err := chanNotifier.SubscribeChans(knownChans)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Next, we'll construct our own backup state so we can add/remove
|
||||
// channels that have been opened and closed.
|
||||
backupState := make(map[wire.OutPoint]Single)
|
||||
for _, chanBackup := range startingChans {
|
||||
backupState[chanBackup.FundingOutpoint] = chanBackup
|
||||
}
|
||||
|
||||
return &SubSwapper{
|
||||
backupState: backupState,
|
||||
chanEvents: chanEvents,
|
||||
keyRing: keyRing,
|
||||
Swapper: backupSwapper,
|
||||
quit: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start starts the chanbackup.SubSwapper.
|
||||
func (s *SubSwapper) Start() error {
|
||||
if !atomic.CompareAndSwapUint32(&s.started, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Starting chanbackup.SubSwapper")
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.backupUpdater()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop signals the SubSwapper to being a graceful shutdown.
|
||||
func (s *SubSwapper) Stop() error {
|
||||
if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("Stopping chanbackup.SubSwapper")
|
||||
|
||||
close(s.quit)
|
||||
s.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// backupFileUpdater is the primary goroutine of the SubSwapper which is
|
||||
// responsible for listening for changes to the channel, and updating the
|
||||
// persistent multi backup state with a new packed multi of the latest channel
|
||||
// state.
|
||||
func (s *SubSwapper) backupUpdater() {
|
||||
// Ensure that once we exit, we'll cancel our active channel
|
||||
// subscription.
|
||||
defer s.chanEvents.Cancel()
|
||||
defer s.wg.Done()
|
||||
|
||||
log.Debugf("SubSwapper's backupUpdater is active!")
|
||||
|
||||
for {
|
||||
select {
|
||||
// The channel state has been modified! We'll evaluate all
|
||||
// changes, and swap out the old packed multi with a new one
|
||||
// with the latest channel state.
|
||||
case chanUpdate := <-s.chanEvents.ChanUpdates:
|
||||
oldStateSize := len(s.backupState)
|
||||
|
||||
// For all new open channels, we'll create a new SCB
|
||||
// given the required information.
|
||||
for _, newChan := range chanUpdate.NewChans {
|
||||
log.Debugf("Adding chanenl %v to backup state",
|
||||
newChan.FundingOutpoint)
|
||||
|
||||
s.backupState[newChan.FundingOutpoint] = NewSingle(
|
||||
newChan.OpenChannel, newChan.Addrs,
|
||||
)
|
||||
}
|
||||
|
||||
// For all closed channels, we'll remove the prior
|
||||
// backup state.
|
||||
for _, closedChan := range chanUpdate.ClosedChans {
|
||||
log.Debugf("Removing channel %v from backup "+
|
||||
"state", newLogClosure(func() string {
|
||||
return closedChan.String()
|
||||
}))
|
||||
|
||||
delete(s.backupState, closedChan)
|
||||
}
|
||||
|
||||
newStateSize := len(s.backupState)
|
||||
|
||||
// With our updated channel state obtained, we'll
|
||||
// create a new multi from our series of singles.
|
||||
var newMulti Multi
|
||||
for _, backup := range s.backupState {
|
||||
newMulti.StaticBackups = append(
|
||||
newMulti.StaticBackups, backup,
|
||||
)
|
||||
}
|
||||
|
||||
// Now that our multi has been assembled, we'll attempt
|
||||
// to pack (encrypt+encode) the new channel state to
|
||||
// our target reader.
|
||||
var b bytes.Buffer
|
||||
err := newMulti.PackToWriter(&b, s.keyRing)
|
||||
if err != nil {
|
||||
log.Errorf("unable to pack multi backup: %v",
|
||||
err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Updating on-disk multi SCB backup: "+
|
||||
"num_old_chans=%v, num_new_chans=%v",
|
||||
oldStateSize, newStateSize)
|
||||
|
||||
// Finally, we'll swap out the old backup for this new
|
||||
// one in a single atomic step.
|
||||
err = s.Swapper.UpdateAndSwap(
|
||||
PackedMulti(b.Bytes()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("unable to update multi "+
|
||||
"backup: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Exit at once if a quit signal is detected.
|
||||
case <-s.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
234
chanbackup/pubsub_test.go
Normal file
234
chanbackup/pubsub_test.go
Normal file
@ -0,0 +1,234 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
type mockSwapper struct {
|
||||
fail bool
|
||||
|
||||
swaps chan PackedMulti
|
||||
}
|
||||
|
||||
func newMockSwapper() *mockSwapper {
|
||||
return &mockSwapper{
|
||||
swaps: make(chan PackedMulti),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockSwapper) UpdateAndSwap(newBackup PackedMulti) error {
|
||||
if m.fail {
|
||||
return fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
m.swaps <- newBackup
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockChannelNotifier struct {
|
||||
fail bool
|
||||
|
||||
chanEvents chan ChannelEvent
|
||||
}
|
||||
|
||||
func newMockChannelNotifier() *mockChannelNotifier {
|
||||
return &mockChannelNotifier{
|
||||
chanEvents: make(chan ChannelEvent),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) (
|
||||
*ChannelSubscription, error) {
|
||||
|
||||
if m.fail {
|
||||
return nil, fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
return &ChannelSubscription{
|
||||
ChanUpdates: m.chanEvents,
|
||||
Cancel: func() {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestNewSubSwapperSubscribeFail tests that if we're unable to obtain a
|
||||
// channel subscription, then the entire sub-swapper will fail to start.
|
||||
func TestNewSubSwapperSubscribeFail(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
var swapper mockSwapper
|
||||
chanNotifier := mockChannelNotifier{
|
||||
fail: true,
|
||||
}
|
||||
|
||||
_, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper)
|
||||
if err == nil {
|
||||
t.Fatalf("expected fail due to lack of subscription")
|
||||
}
|
||||
}
|
||||
|
||||
func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper,
|
||||
subSwapper *SubSwapper, keyRing keychain.KeyRing,
|
||||
expectedChanSet map[wire.OutPoint]Single) {
|
||||
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case newPackedMulti := <-swapper.swaps:
|
||||
// If we unpack the new multi, then we should find all the old
|
||||
// channels, and also the new channel included and any deleted
|
||||
// channel omitted..
|
||||
newMulti, err := newPackedMulti.Unpack(keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unpack multi: %v", err)
|
||||
}
|
||||
|
||||
// Ensure that once unpacked, the current backup has the
|
||||
// expected number of Singles.
|
||||
if len(newMulti.StaticBackups) != len(expectedChanSet) {
|
||||
t.Fatalf("new backup wasn't included: expected %v "+
|
||||
"backups have %v", len(expectedChanSet),
|
||||
len(newMulti.StaticBackups))
|
||||
}
|
||||
|
||||
// We should also find all the old and new channels in this new
|
||||
// backup.
|
||||
for _, backup := range newMulti.StaticBackups {
|
||||
_, ok := expectedChanSet[backup.FundingOutpoint]
|
||||
if !ok {
|
||||
t.Fatalf("didn't find backup in original set: %v",
|
||||
backup.FundingOutpoint)
|
||||
}
|
||||
}
|
||||
|
||||
// The internal state of the sub-swapper should also be one
|
||||
// larger.
|
||||
if !reflect.DeepEqual(expectedChanSet, subSwapper.backupState) {
|
||||
t.Fatalf("backup set doesn't match: expected %v got %v",
|
||||
spew.Sdump(expectedChanSet),
|
||||
spew.Sdump(subSwapper.backupState))
|
||||
}
|
||||
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatalf("update swapper didn't swap out multi")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubSwapperIdempotentStartStop tests that calling the Start/Stop methods
|
||||
// multiple time is permitted.
|
||||
func TestSubSwapperIdempotentStartStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
var (
|
||||
swapper mockSwapper
|
||||
chanNotifier mockChannelNotifier
|
||||
)
|
||||
|
||||
subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init subSwapper: %v", err)
|
||||
}
|
||||
|
||||
subSwapper.Start()
|
||||
subSwapper.Start()
|
||||
|
||||
subSwapper.Stop()
|
||||
subSwapper.Stop()
|
||||
}
|
||||
|
||||
// TestSubSwapperUpdater tests that the SubSwapper will properly swap out
|
||||
// new/old channels within the channel set, and notify the swapper to update
|
||||
// the master multi file backup.
|
||||
func TestSubSwapperUpdater(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
chanNotifier := newMockChannelNotifier()
|
||||
swapper := newMockSwapper()
|
||||
|
||||
// First, we'll start out by creating a channels set for the initial
|
||||
// set of channels known to the sub-swapper.
|
||||
const numStartingChans = 3
|
||||
initialChanSet := make([]Single, 0, numStartingChans)
|
||||
backupSet := make(map[wire.OutPoint]Single)
|
||||
for i := 0; i < numStartingChans; i++ {
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make test chan: %v", err)
|
||||
}
|
||||
|
||||
single := NewSingle(channel, nil)
|
||||
|
||||
backupSet[channel.FundingOutpoint] = single
|
||||
initialChanSet = append(initialChanSet, single)
|
||||
}
|
||||
|
||||
// With our channel set created, we'll make a fresh sub swapper
|
||||
// instance to begin our test.
|
||||
subSwapper, err := NewSubSwapper(
|
||||
initialChanSet, chanNotifier, keyRing, swapper,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make swapper: %v", err)
|
||||
}
|
||||
if err := subSwapper.Start(); err != nil {
|
||||
t.Fatalf("unable to start sub swapper: %v", err)
|
||||
}
|
||||
defer subSwapper.Stop()
|
||||
|
||||
// Now that the sub-swapper is active, we'll notify to add a brand new
|
||||
// channel to the channel state.
|
||||
newChannel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new chan: %v", err)
|
||||
}
|
||||
|
||||
// With the new channel created, we'll send a new update to the main
|
||||
// goroutine telling it about this new channel.
|
||||
select {
|
||||
case chanNotifier.chanEvents <- ChannelEvent{
|
||||
NewChans: []ChannelWithAddrs{
|
||||
{
|
||||
OpenChannel: newChannel,
|
||||
},
|
||||
},
|
||||
}:
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatalf("update swapper didn't read new channel: %v", err)
|
||||
}
|
||||
|
||||
backupSet[newChannel.FundingOutpoint] = NewSingle(newChannel, nil)
|
||||
|
||||
// At this point, the sub-swapper should now have packed a new multi,
|
||||
// and then sent it to the swapper so the back up can be updated.
|
||||
assertExpectedBackupSwap(t, swapper, subSwapper, keyRing, backupSet)
|
||||
|
||||
// We'll now trigger an update to remove an existing channel.
|
||||
chanToDelete := initialChanSet[0].FundingOutpoint
|
||||
select {
|
||||
case chanNotifier.chanEvents <- ChannelEvent{
|
||||
ClosedChans: []wire.OutPoint{chanToDelete},
|
||||
}:
|
||||
|
||||
case <-time.After(time.Second * 5):
|
||||
t.Fatalf("update swapper didn't read new channel: %v", err)
|
||||
}
|
||||
|
||||
delete(backupSet, chanToDelete)
|
||||
|
||||
// Verify that the new set of backups, now has one less after the
|
||||
// sub-swapper switches the new set with the old.
|
||||
assertExpectedBackupSwap(t, swapper, subSwapper, keyRing, backupSet)
|
||||
}
|
114
chanbackup/recover.go
Normal file
114
chanbackup/recover.go
Normal file
@ -0,0 +1,114 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
)
|
||||
|
||||
// ChannelRestorer is an interface that allows the Recover method to map the
|
||||
// set of single channel backups into a set of "channel shells" and store these
|
||||
// persistently on disk. The channel shell should contain all the information
|
||||
// needed to execute the data loss recovery protocol once the channel peer is
|
||||
// connected to.
|
||||
type ChannelRestorer interface {
|
||||
// RestoreChansFromSingles attempts to map the set of single channel
|
||||
// backups to channel shells that will be stored persistently. Once
|
||||
// these shells have been stored on disk, we'll be able to connect to
|
||||
// the channel peer an execute the data loss recovery protocol.
|
||||
RestoreChansFromSingles(...Single) error
|
||||
}
|
||||
|
||||
// PeerConnector is an interface that allows the Recover method to connect to
|
||||
// the target node given the set of possible addresses.
|
||||
type PeerConnector interface {
|
||||
// ConnectPeer attempts to connect to the target node at the set of
|
||||
// available addresses. Once this method returns with a non-nil error,
|
||||
// the connector should attempt to persistently connect to the target
|
||||
// peer in the background as a persistent attempt.
|
||||
ConnectPeer(node *btcec.PublicKey, addrs []net.Addr) error
|
||||
}
|
||||
|
||||
// Recover attempts to recover the static channel state from a set of static
|
||||
// channel backups. If successfully, the database will be populated with a
|
||||
// series of "shell" channels. These "shell" channels cannot be used to operate
|
||||
// the channel as normal, but instead are meant to be used to enter the data
|
||||
// loss recovery phase, and recover the settled funds within
|
||||
// the channel. In addition a LinkNode will be created for each new peer as
|
||||
// well, in order to expose the addressing information required to locate to
|
||||
// and connect to each peer in order to initiate the recovery protocol.
|
||||
func Recover(backups []Single, restorer ChannelRestorer,
|
||||
peerConnector PeerConnector) error {
|
||||
|
||||
for _, backup := range backups {
|
||||
log.Infof("Restoring ChannelPoint(%v) to disk: ",
|
||||
backup.FundingOutpoint)
|
||||
|
||||
err := restorer.RestoreChansFromSingles(backup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Attempting to connect to node=%x (addrs=%v) to "+
|
||||
"restore ChannelPoint(%v)",
|
||||
backup.RemoteNodePub.SerializeCompressed(),
|
||||
newLogClosure(func() string {
|
||||
return spew.Sdump(backup.Addresses)
|
||||
}), backup.FundingOutpoint)
|
||||
|
||||
err = peerConnector.ConnectPeer(
|
||||
backup.RemoteNodePub, backup.Addresses,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(roasbeef): to handle case where node has changed addrs,
|
||||
// need to subscribe to new updates for target node pub to
|
||||
// attempt to connect to other addrs
|
||||
//
|
||||
// * just to to fresh w/ call to node addrs and de-dup?
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(roasbeef): more specific keychain interface?
|
||||
|
||||
// UnpackAndRecoverSingles is a one-shot method, that given a set of packed
|
||||
// single channel backups, will restore the channel state to a channel shell,
|
||||
// and also reach out to connect to any of the known node addresses for that
|
||||
// channel. It is assumes that after this method exists, if a connection we
|
||||
// able to be established, then then PeerConnector will continue to attempt to
|
||||
// re-establish a persistent connection in the background.
|
||||
func UnpackAndRecoverSingles(singles PackedSingles,
|
||||
keyChain keychain.KeyRing, restorer ChannelRestorer,
|
||||
peerConnector PeerConnector) error {
|
||||
|
||||
chanBackups, err := singles.Unpack(keyChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Recover(chanBackups, restorer, peerConnector)
|
||||
}
|
||||
|
||||
// UnpackAndRecoverMulti is a one-shot method, that given a set of packed
|
||||
// multi-channel backups, will restore the channel states to channel shells,
|
||||
// and also reach out to connect to any of the known node addresses for that
|
||||
// channel. It is assumes that after this method exists, if a connection we
|
||||
// able to be established, then then PeerConnector will continue to attempt to
|
||||
// re-establish a persistent connection in the background.
|
||||
func UnpackAndRecoverMulti(packedMulti PackedMulti,
|
||||
keyChain keychain.KeyRing, restorer ChannelRestorer,
|
||||
peerConnector PeerConnector) error {
|
||||
|
||||
chanBackups, err := packedMulti.Unpack(keyChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return Recover(chanBackups.StaticBackups, restorer, peerConnector)
|
||||
}
|
232
chanbackup/recover_test.go
Normal file
232
chanbackup/recover_test.go
Normal file
@ -0,0 +1,232 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
)
|
||||
|
||||
type mockChannelRestorer struct {
|
||||
fail bool
|
||||
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (m *mockChannelRestorer) RestoreChansFromSingles(...Single) error {
|
||||
if m.fail {
|
||||
return fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
m.callCount++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockPeerConnector struct {
|
||||
fail bool
|
||||
|
||||
callCount int
|
||||
}
|
||||
|
||||
func (m *mockPeerConnector) ConnectPeer(node *btcec.PublicKey,
|
||||
addrs []net.Addr) error {
|
||||
|
||||
if m.fail {
|
||||
return fmt.Errorf("fail")
|
||||
}
|
||||
|
||||
m.callCount++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestUnpackAndRecoverSingles tests that we're able to properly unpack and
|
||||
// recover a set of packed singles.
|
||||
func TestUnpackAndRecoverSingles(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
// First, we'll create a number of single chan backups that we'll
|
||||
// shortly back to so we can begin our recovery attempt.
|
||||
numSingles := 10
|
||||
backups := make([]Single, 0, numSingles)
|
||||
var packedBackups PackedSingles
|
||||
for i := 0; i < numSingles; i++ {
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable make channel: %v", err)
|
||||
}
|
||||
|
||||
single := NewSingle(channel, nil)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := single.PackToWriter(&b, keyRing); err != nil {
|
||||
t.Fatalf("unable to pack single: %v", err)
|
||||
}
|
||||
|
||||
backups = append(backups, single)
|
||||
packedBackups = append(packedBackups, b.Bytes())
|
||||
}
|
||||
|
||||
chanRestorer := mockChannelRestorer{}
|
||||
peerConnector := mockPeerConnector{}
|
||||
|
||||
// Now that we have our backups (packed and unpacked), we'll attempt to
|
||||
// restore them all in a single batch.
|
||||
|
||||
// If we make the channel restore fail, then the entire method should
|
||||
// as well
|
||||
chanRestorer.fail = true
|
||||
err := UnpackAndRecoverSingles(
|
||||
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("restoration should have failed")
|
||||
}
|
||||
|
||||
chanRestorer.fail = false
|
||||
|
||||
// If we make the peer connector fail, then the entire method should as
|
||||
// well
|
||||
peerConnector.fail = true
|
||||
err = UnpackAndRecoverSingles(
|
||||
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("restoration should have failed")
|
||||
}
|
||||
|
||||
chanRestorer.callCount--
|
||||
peerConnector.fail = false
|
||||
|
||||
// Next, we'll ensure that if all the interfaces function as expected,
|
||||
// then the channels will properly be unpacked and restored.
|
||||
err = UnpackAndRecoverSingles(
|
||||
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to recover chans: %v", err)
|
||||
}
|
||||
|
||||
// Both the restorer, and connector should have been called 10 times,
|
||||
// once for each backup.
|
||||
if chanRestorer.callCount != numSingles {
|
||||
t.Fatalf("expected %v calls, instead got %v",
|
||||
numSingles, chanRestorer.callCount)
|
||||
}
|
||||
if peerConnector.callCount != numSingles {
|
||||
t.Fatalf("expected %v calls, instead got %v",
|
||||
numSingles, peerConnector.callCount)
|
||||
}
|
||||
|
||||
// If we modify the keyRing, then unpacking should fail.
|
||||
keyRing.fail = true
|
||||
err = UnpackAndRecoverSingles(
|
||||
packedBackups, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("unpacking should have failed")
|
||||
}
|
||||
|
||||
// TODO(roasbeef): verify proper call args
|
||||
}
|
||||
|
||||
// TestUnpackAndRecoverMulti tests that we're able to properly unpack and
|
||||
// recover a packed multi.
|
||||
func TestUnpackAndRecoverMulti(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
// First, we'll create a number of single chan backups that we'll
|
||||
// shortly back to so we can begin our recovery attempt.
|
||||
numSingles := 10
|
||||
backups := make([]Single, 0, numSingles)
|
||||
for i := 0; i < numSingles; i++ {
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable make channel: %v", err)
|
||||
}
|
||||
|
||||
single := NewSingle(channel, nil)
|
||||
|
||||
backups = append(backups, single)
|
||||
}
|
||||
|
||||
multi := Multi{
|
||||
StaticBackups: backups,
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := multi.PackToWriter(&b, keyRing); err != nil {
|
||||
t.Fatalf("unable to pack multi: %v", err)
|
||||
}
|
||||
|
||||
// Next, we'll pack the set of singles into a packed multi, and also
|
||||
// create the set of interfaces we need to carry out the remainder of
|
||||
// the test.
|
||||
packedMulti := PackedMulti(b.Bytes())
|
||||
|
||||
chanRestorer := mockChannelRestorer{}
|
||||
peerConnector := mockPeerConnector{}
|
||||
|
||||
// If we make the channel restore fail, then the entire method should
|
||||
// as well
|
||||
chanRestorer.fail = true
|
||||
err := UnpackAndRecoverMulti(
|
||||
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("restoration should have failed")
|
||||
}
|
||||
|
||||
chanRestorer.fail = false
|
||||
|
||||
// If we make the peer connector fail, then the entire method should as
|
||||
// well
|
||||
peerConnector.fail = true
|
||||
err = UnpackAndRecoverMulti(
|
||||
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("restoration should have failed")
|
||||
}
|
||||
|
||||
chanRestorer.callCount--
|
||||
peerConnector.fail = false
|
||||
|
||||
// Next, we'll ensure that if all the interfaces function as expected,
|
||||
// then the channels will properly be unpacked and restored.
|
||||
err = UnpackAndRecoverMulti(
|
||||
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to recover chans: %v", err)
|
||||
}
|
||||
|
||||
// Both the restorer, and connector should have been called 10 times,
|
||||
// once for each backup.
|
||||
if chanRestorer.callCount != numSingles {
|
||||
t.Fatalf("expected %v calls, instead got %v",
|
||||
numSingles, chanRestorer.callCount)
|
||||
}
|
||||
if peerConnector.callCount != numSingles {
|
||||
t.Fatalf("expected %v calls, instead got %v",
|
||||
numSingles, peerConnector.callCount)
|
||||
}
|
||||
|
||||
// If we modify the keyRing, then unpacking should fail.
|
||||
keyRing.fail = true
|
||||
err = UnpackAndRecoverMulti(
|
||||
packedMulti, keyRing, &chanRestorer, &peerConnector,
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("unpacking should have failed")
|
||||
}
|
||||
|
||||
// TODO(roasbeef): verify proper call args
|
||||
}
|
346
chanbackup/single.go
Normal file
346
chanbackup/single.go
Normal file
@ -0,0 +1,346 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// SingleBackupVersion denotes the version of the single static channel backup.
|
||||
// Based on this version, we know how to pack/unpack serialized versions of the
|
||||
// backup.
|
||||
type SingleBackupVersion byte
|
||||
|
||||
const (
|
||||
// DefaultSingleVersion is the defautl version of the single channel
|
||||
// backup. The seralized version of this static channel backup is
|
||||
// simply: version || SCB. Where SCB is the known format of the
|
||||
// version.
|
||||
DefaultSingleVersion = 0
|
||||
)
|
||||
|
||||
// Single is a static description of an existing channel that can be used for
|
||||
// the purposes of backing up. The fields in this struct allow a node to
|
||||
// recover the settled funds within a channel in the case of partial or
|
||||
// complete data loss. We provide the network address that we last used to
|
||||
// connect to the peer as well, in case the node stops advertising the IP on
|
||||
// the network for whatever reason.
|
||||
//
|
||||
// TODO(roasbeef): suffix version into struct?
|
||||
type Single struct {
|
||||
// Version is the version that should be observed when attempting to
|
||||
// pack the single backup.
|
||||
Version SingleBackupVersion
|
||||
|
||||
// ChainHash is a hash which represents the blockchain that this
|
||||
// channel will be opened within. This value is typically the genesis
|
||||
// hash. In the case that the original chain went through a contentious
|
||||
// hard-fork, then this value will be tweaked using the unique fork
|
||||
// point on each branch.
|
||||
ChainHash chainhash.Hash
|
||||
|
||||
// FundingOutpoint is the outpoint of the final funding transaction.
|
||||
// This value uniquely and globally identities the channel within the
|
||||
// target blockchain as specified by the chain hash parameter.
|
||||
FundingOutpoint wire.OutPoint
|
||||
|
||||
// ShortChannelID encodes the exact location in the chain in which the
|
||||
// channel was initially confirmed. This includes: the block height,
|
||||
// transaction index, and the output within the target transaction.
|
||||
ShortChannelID lnwire.ShortChannelID
|
||||
|
||||
// RemoteNodePub is the identity public key of the remote node this
|
||||
// channel has been established with.
|
||||
RemoteNodePub *btcec.PublicKey
|
||||
|
||||
// Addresses is a list of IP address in which either we were able to
|
||||
// reach the node over in the past, OR we received an incoming
|
||||
// authenticated connection for the stored identity public key.
|
||||
Addresses []net.Addr
|
||||
|
||||
// CsvDelay is the local CSV delay used within the channel. We may need
|
||||
// this value to reconstruct our script to recover the funds on-chain
|
||||
// after a force close.
|
||||
CsvDelay uint16
|
||||
|
||||
// PaymentBasePoint describes how to derive base public that's used to
|
||||
// deriving the key used within the non-delayed pay-to-self output on
|
||||
// the commitment transaction for a node. With this information, we can
|
||||
// re-derive the private key needed to sweep the funds on-chain.
|
||||
PaymentBasePoint keychain.KeyLocator
|
||||
|
||||
// ShaChainRootDesc describes how to derive the private key that was
|
||||
// used as the shachain root for this channel.
|
||||
ShaChainRootDesc keychain.KeyDescriptor
|
||||
}
|
||||
|
||||
// NewSingle creates a new static channel backup based on an existing open
|
||||
// channel. We also pass in the set of addresses that we used in the past to
|
||||
// connect to the channel peer.
|
||||
func NewSingle(channel *channeldb.OpenChannel,
|
||||
nodeAddrs []net.Addr) Single {
|
||||
|
||||
chanCfg := channel.LocalChanCfg
|
||||
|
||||
// TODO(roasbeef): update after we start to store the KeyLoc for
|
||||
// shachain root
|
||||
|
||||
// We'll need to obtain the shachain root which is derived directly
|
||||
// from a private key in our keychain.
|
||||
var b bytes.Buffer
|
||||
channel.RevocationProducer.Encode(&b) // Can't return an error.
|
||||
|
||||
// Once we have the root, we'll make a public key from it, such that
|
||||
// the backups plaintext don't carry any private information. When we
|
||||
// go to recover, we'll present this in order to derive the private
|
||||
// key.
|
||||
_, shaChainPoint := btcec.PrivKeyFromBytes(btcec.S256(), b.Bytes())
|
||||
|
||||
return Single{
|
||||
ChainHash: channel.ChainHash,
|
||||
FundingOutpoint: channel.FundingOutpoint,
|
||||
ShortChannelID: channel.ShortChannelID,
|
||||
RemoteNodePub: channel.IdentityPub,
|
||||
Addresses: nodeAddrs,
|
||||
CsvDelay: chanCfg.CsvDelay,
|
||||
PaymentBasePoint: chanCfg.PaymentBasePoint.KeyLocator,
|
||||
ShaChainRootDesc: keychain.KeyDescriptor{
|
||||
PubKey: shaChainPoint,
|
||||
KeyLocator: keychain.KeyLocator{
|
||||
Family: keychain.KeyFamilyRevocationRoot,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize attempts to write out the serialized version of the target
|
||||
// StaticChannelBackup into the passed io.Writer.
|
||||
func (s *Single) Serialize(w io.Writer) error {
|
||||
// Check to ensure that we'll only attempt to serialize a version that
|
||||
// we're aware of.
|
||||
switch s.Version {
|
||||
case DefaultSingleVersion:
|
||||
default:
|
||||
return fmt.Errorf("unable to serialize w/ unknown "+
|
||||
"version: %v", s.Version)
|
||||
}
|
||||
|
||||
// If the sha chain root has specified a public key (which is
|
||||
// optional), then we'll encode it now.
|
||||
var shaChainPub [33]byte
|
||||
if s.ShaChainRootDesc.PubKey != nil {
|
||||
copy(
|
||||
shaChainPub[:],
|
||||
s.ShaChainRootDesc.PubKey.SerializeCompressed(),
|
||||
)
|
||||
}
|
||||
|
||||
// First we gather the SCB as is into a temporary buffer so we can
|
||||
// determine the total length. Before we write out the serialized SCB,
|
||||
// we write the length which allows us to skip any Singles that we
|
||||
// don't know of when decoding a multi.
|
||||
var singleBytes bytes.Buffer
|
||||
if err := lnwire.WriteElements(
|
||||
&singleBytes,
|
||||
s.ChainHash[:],
|
||||
s.FundingOutpoint,
|
||||
s.ShortChannelID,
|
||||
s.RemoteNodePub,
|
||||
s.Addresses,
|
||||
s.CsvDelay,
|
||||
uint32(s.PaymentBasePoint.Family),
|
||||
s.PaymentBasePoint.Index,
|
||||
shaChainPub[:],
|
||||
uint32(s.ShaChainRootDesc.KeyLocator.Family),
|
||||
s.ShaChainRootDesc.KeyLocator.Index,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return lnwire.WriteElements(
|
||||
w,
|
||||
byte(s.Version),
|
||||
uint16(len(singleBytes.Bytes())),
|
||||
singleBytes.Bytes(),
|
||||
)
|
||||
}
|
||||
|
||||
// PackToWriter is similar to the Serialize method, but takes the operation a
|
||||
// step further by encryption the raw bytes of the static channel back up. For
|
||||
// encryption we use the chacah20poly1305 AEAD cipher with a 24 byte nonce and
|
||||
// 32-byte key size. We use a 24-byte nonce, as we can't ensure that we have a
|
||||
// global counter to use as a sequence number for nonces, and want to ensure
|
||||
// that we're able to decrypt these blobs without any additional context. We
|
||||
// derive the key that we use for encryption via a SHA2 operation of the with
|
||||
// the golden keychain.KeyFamilyStaticBackup base encryption key. We then take
|
||||
// the serialized resulting shared secret point, and hash it using sha256 to
|
||||
// obtain the key that we'll use for encryption. When using the AEAD, we pass
|
||||
// the nonce as associated data such that we'll be able to package the two
|
||||
// together for storage. Before writing out the encrypted payload, we prepend
|
||||
// the nonce to the final blob.
|
||||
func (s *Single) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error {
|
||||
// First, we'll serialize the SCB (StaticChannelBackup) into a
|
||||
// temporary buffer so we can store it in a temporary place before we
|
||||
// go to encrypt the entire thing.
|
||||
var rawBytes bytes.Buffer
|
||||
if err := s.Serialize(&rawBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, we'll encrypt the raw serialized SCB (using the nonce as
|
||||
// associated data), and write out the ciphertext prepend with the
|
||||
// nonce that we used to the passed io.Reader.
|
||||
return encryptPayloadToWriter(rawBytes, w, keyRing)
|
||||
}
|
||||
|
||||
// Deserialize attempts to read the raw plaintext serialized SCB from the
|
||||
// passed io.Reader. If the method is successful, then the target
|
||||
// StaticChannelBackup will be fully populated.
|
||||
func (s *Single) Deserialize(r io.Reader) error {
|
||||
// First, we'll need to read the version of this single-back up so we
|
||||
// can know how to unpack each of the SCB.
|
||||
var version byte
|
||||
err := lnwire.ReadElements(r, &version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Version = SingleBackupVersion(version)
|
||||
|
||||
switch s.Version {
|
||||
case DefaultSingleVersion:
|
||||
default:
|
||||
return fmt.Errorf("unable to de-serialize w/ unknown "+
|
||||
"version: %v", s.Version)
|
||||
}
|
||||
|
||||
var length uint16
|
||||
if err := lnwire.ReadElements(r, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = lnwire.ReadElements(
|
||||
r, s.ChainHash[:], &s.FundingOutpoint, &s.ShortChannelID,
|
||||
&s.RemoteNodePub, &s.Addresses, &s.CsvDelay,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var keyFam uint32
|
||||
if err := lnwire.ReadElements(r, &keyFam); err != nil {
|
||||
return err
|
||||
}
|
||||
s.PaymentBasePoint.Family = keychain.KeyFamily(keyFam)
|
||||
|
||||
err = lnwire.ReadElements(r, &s.PaymentBasePoint.Index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, we'll parse out the ShaChainRootDesc.
|
||||
var (
|
||||
shaChainPub [33]byte
|
||||
zeroPub [33]byte
|
||||
)
|
||||
if err := lnwire.ReadElements(r, shaChainPub[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since this field is optional, we'll check to see if the pubkey has
|
||||
// ben specified or not.
|
||||
if !bytes.Equal(shaChainPub[:], zeroPub[:]) {
|
||||
s.ShaChainRootDesc.PubKey, err = btcec.ParsePubKey(
|
||||
shaChainPub[:], btcec.S256(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var shaKeyFam uint32
|
||||
if err := lnwire.ReadElements(r, &shaKeyFam); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ShaChainRootDesc.KeyLocator.Family = keychain.KeyFamily(shaKeyFam)
|
||||
|
||||
return lnwire.ReadElements(r, &s.ShaChainRootDesc.KeyLocator.Index)
|
||||
}
|
||||
|
||||
// UnpackFromReader is similar to Deserialize method, but it expects the passed
|
||||
// io.Reader to contain an encrypt SCB. Refer to the SerializeAndEncrypt method
|
||||
// for details w.r.t the encryption scheme used. If we're unable to decrypt the
|
||||
// payload for whatever reason (wrong key, wrong nonce, etc), then this method
|
||||
// will return an error.
|
||||
func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error {
|
||||
plaintext, err := decryptPayloadFromReader(r, keyRing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, we'll pack the bytes into a reader to we can deserialize
|
||||
// the plaintext bytes of the SCB.
|
||||
backupReader := bytes.NewReader(plaintext)
|
||||
return s.Deserialize(backupReader)
|
||||
}
|
||||
|
||||
// PackStaticChanBackups accepts a set of existing open channels, and a
|
||||
// keychain.KeyRing, and returns a map of outpoints to the serialized+encrypted
|
||||
// static channel backups. The passed keyRing should be backed by the users
|
||||
// root HD seed in order to ensure full determinism.
|
||||
func PackStaticChanBackups(backups []Single,
|
||||
keyRing keychain.KeyRing) (map[wire.OutPoint][]byte, error) {
|
||||
|
||||
packedBackups := make(map[wire.OutPoint][]byte)
|
||||
for _, chanBackup := range backups {
|
||||
chanPoint := chanBackup.FundingOutpoint
|
||||
|
||||
var b bytes.Buffer
|
||||
err := chanBackup.PackToWriter(&b, keyRing)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to pack chan backup "+
|
||||
"for %v: %v", chanPoint, err)
|
||||
}
|
||||
|
||||
packedBackups[chanPoint] = b.Bytes()
|
||||
}
|
||||
|
||||
return packedBackups, nil
|
||||
}
|
||||
|
||||
// PackedSingles represents a series of fully packed SCBs. This may be the
|
||||
// combination of a series of individual SCBs in order to batch their
|
||||
// unpacking.
|
||||
type PackedSingles [][]byte
|
||||
|
||||
// Unpack attempts to decrypt the passed set of encrypted SCBs and deserialize
|
||||
// each one into a new SCB struct. The passed keyRing should be backed by the
|
||||
// same HD seed as was used to encrypt the set of backups in the first place.
|
||||
// If we're unable to decrypt any of the back ups, then we'll return an error.
|
||||
func (p PackedSingles) Unpack(keyRing keychain.KeyRing) ([]Single, error) {
|
||||
|
||||
backups := make([]Single, len(p))
|
||||
for i, encryptedBackup := range p {
|
||||
var backup Single
|
||||
|
||||
backupReader := bytes.NewReader(encryptedBackup)
|
||||
err := backup.UnpackFromReader(backupReader, keyRing)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backups[i] = backup
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
// TODO(roasbeef): make codec package?
|
342
chanbackup/single_test.go
Normal file
342
chanbackup/single_test.go
Normal file
@ -0,0 +1,342 @@
|
||||
package chanbackup
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/shachain"
|
||||
)
|
||||
|
||||
var (
|
||||
chainHash = chainhash.Hash{
|
||||
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
|
||||
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
|
||||
0x4f, 0x2f, 0x6f, 0x25, 0x18, 0xa3, 0xef, 0xb9,
|
||||
0x64, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
|
||||
}
|
||||
|
||||
op = wire.OutPoint{
|
||||
Hash: chainHash,
|
||||
Index: 4,
|
||||
}
|
||||
|
||||
addr1, _ = net.ResolveTCPAddr("tcp", "10.0.0.2:9000")
|
||||
addr2, _ = net.ResolveTCPAddr("tcp", "10.0.0.3:9000")
|
||||
)
|
||||
|
||||
func assertSingleEqual(t *testing.T, a, b Single) {
|
||||
t.Helper()
|
||||
|
||||
if a.Version != b.Version {
|
||||
t.Fatalf("versions don't match: %v vs %v", a.Version,
|
||||
b.Version)
|
||||
}
|
||||
if a.ChainHash != b.ChainHash {
|
||||
t.Fatalf("chainhash doesn't match: %v vs %v", a.ChainHash,
|
||||
b.ChainHash)
|
||||
}
|
||||
if a.FundingOutpoint != b.FundingOutpoint {
|
||||
t.Fatalf("chan point doesn't match: %v vs %v",
|
||||
a.FundingOutpoint, b.FundingOutpoint)
|
||||
}
|
||||
if a.ShortChannelID != b.ShortChannelID {
|
||||
t.Fatalf("chan id doesn't match: %v vs %v",
|
||||
a.ShortChannelID, b.ShortChannelID)
|
||||
}
|
||||
if !a.RemoteNodePub.IsEqual(b.RemoteNodePub) {
|
||||
t.Fatalf("node pubs don't match %x vs %x",
|
||||
a.RemoteNodePub.SerializeCompressed(),
|
||||
b.RemoteNodePub.SerializeCompressed())
|
||||
}
|
||||
if a.CsvDelay != b.CsvDelay {
|
||||
t.Fatalf("csv delay doesn't match: %v vs %v", a.CsvDelay,
|
||||
b.CsvDelay)
|
||||
}
|
||||
if !reflect.DeepEqual(a.PaymentBasePoint, b.PaymentBasePoint) {
|
||||
t.Fatalf("base point doesn't match: %v vs %v",
|
||||
spew.Sdump(a.PaymentBasePoint),
|
||||
spew.Sdump(b.PaymentBasePoint))
|
||||
}
|
||||
if !reflect.DeepEqual(a.ShaChainRootDesc, b.ShaChainRootDesc) {
|
||||
t.Fatalf("sha chain point doesn't match: %v vs %v",
|
||||
spew.Sdump(a.PaymentBasePoint),
|
||||
spew.Sdump(b.PaymentBasePoint))
|
||||
}
|
||||
|
||||
if len(a.Addresses) != len(b.Addresses) {
|
||||
t.Fatalf("expected %v addrs got %v", len(a.Addresses),
|
||||
len(b.Addresses))
|
||||
}
|
||||
for i := 0; i < len(a.Addresses); i++ {
|
||||
if a.Addresses[i].String() != b.Addresses[i].String() {
|
||||
t.Fatalf("addr mismatch: %v vs %v",
|
||||
a.Addresses[i], b.Addresses[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) {
|
||||
var testPriv [32]byte
|
||||
if _, err := rand.Read(testPriv[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:])
|
||||
|
||||
var chanPoint wire.OutPoint
|
||||
if _, err := rand.Read(chanPoint.Hash[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pub.Curve = nil
|
||||
|
||||
chanPoint.Index = uint32(rand.Intn(math.MaxUint16))
|
||||
|
||||
var shaChainRoot [32]byte
|
||||
if _, err := rand.Read(shaChainRoot[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shaChainProducer := shachain.NewRevocationProducer(shaChainRoot)
|
||||
|
||||
return &channeldb.OpenChannel{
|
||||
ChainHash: chainHash,
|
||||
FundingOutpoint: chanPoint,
|
||||
ShortChannelID: lnwire.NewShortChanIDFromInt(
|
||||
uint64(rand.Int63()),
|
||||
),
|
||||
IdentityPub: pub,
|
||||
LocalChanCfg: channeldb.ChannelConfig{
|
||||
ChannelConstraints: channeldb.ChannelConstraints{
|
||||
CsvDelay: uint16(rand.Int63()),
|
||||
},
|
||||
PaymentBasePoint: keychain.KeyDescriptor{
|
||||
KeyLocator: keychain.KeyLocator{
|
||||
Family: keychain.KeyFamily(rand.Int63()),
|
||||
Index: uint32(rand.Int63()),
|
||||
},
|
||||
},
|
||||
},
|
||||
RevocationProducer: shaChainProducer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestSinglePackUnpack tests that we're able to unpack a previously packed
|
||||
// channel backup.
|
||||
func TestSinglePackUnpack(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given our test pub key, we'll create an open channel shell that
|
||||
// contains all the information we need to create a static channel
|
||||
// backup.
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen open channel: %v", err)
|
||||
}
|
||||
|
||||
singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2})
|
||||
singleChanBackup.RemoteNodePub.Curve = nil
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
versionTestCases := []struct {
|
||||
// version is the pack/unpack version that we should use to
|
||||
// decode/encode the final SCB.
|
||||
version SingleBackupVersion
|
||||
|
||||
// valid tests us if this test case should pass or not.
|
||||
valid bool
|
||||
}{
|
||||
// The default version, should pack/unpack with no problem.
|
||||
{
|
||||
version: DefaultSingleVersion,
|
||||
valid: true,
|
||||
},
|
||||
|
||||
// A non-default version, atm this should result in a failure.
|
||||
{
|
||||
version: 99,
|
||||
valid: false,
|
||||
},
|
||||
}
|
||||
for i, versionCase := range versionTestCases {
|
||||
// First, we'll re-assign SCB version to what was indicated in
|
||||
// the test case.
|
||||
singleChanBackup.Version = versionCase.version
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
err := singleChanBackup.PackToWriter(&b, keyRing)
|
||||
switch {
|
||||
// If this is a valid test case, and we failed, then we'll
|
||||
// return an error.
|
||||
case err != nil && versionCase.valid:
|
||||
t.Fatalf("#%v, unable to pack single: %v", i, err)
|
||||
|
||||
// If this is an invalid test case, and we passed it, then
|
||||
// we'll return an error.
|
||||
case err == nil && !versionCase.valid:
|
||||
t.Fatalf("#%v got nil error for invalid pack: %v",
|
||||
i, err)
|
||||
}
|
||||
|
||||
// If this is a valid test case, then we'll continue to ensure
|
||||
// we can unpack it, and also that if we mutate the packed
|
||||
// version, then we trigger an error.
|
||||
if versionCase.valid {
|
||||
var unpackedSingle Single
|
||||
err = unpackedSingle.UnpackFromReader(&b, keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("#%v unable to unpack single: %v",
|
||||
i, err)
|
||||
}
|
||||
unpackedSingle.RemoteNodePub.Curve = nil
|
||||
|
||||
assertSingleEqual(t, singleChanBackup, unpackedSingle)
|
||||
|
||||
// If this was a valid packing attempt, then we'll test
|
||||
// to ensure that if we mutate the version prepended to
|
||||
// the serialization, then unpacking will fail as well.
|
||||
var rawSingle bytes.Buffer
|
||||
err := unpackedSingle.Serialize(&rawSingle)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to serialize single: %v", err)
|
||||
}
|
||||
|
||||
rawBytes := rawSingle.Bytes()
|
||||
rawBytes[0] ^= 1
|
||||
|
||||
newReader := bytes.NewReader(rawBytes)
|
||||
err = unpackedSingle.Deserialize(newReader)
|
||||
if err == nil {
|
||||
t.Fatalf("#%v unpack with unknown version "+
|
||||
"should have failed", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestPackedSinglesUnpack tests that we're able to properly unpack a series of
|
||||
// packed singles.
|
||||
func TestPackedSinglesUnpack(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
// To start, we'll create 10 new singles, and them assemble their
|
||||
// packed forms into a slice.
|
||||
numSingles := 10
|
||||
packedSingles := make([][]byte, 0, numSingles)
|
||||
unpackedSingles := make([]Single, 0, numSingles)
|
||||
for i := 0; i < numSingles; i++ {
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen channel: %v", err)
|
||||
}
|
||||
|
||||
single := NewSingle(channel, nil)
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := single.PackToWriter(&b, keyRing); err != nil {
|
||||
t.Fatalf("unable to pack single: %v", err)
|
||||
}
|
||||
|
||||
packedSingles = append(packedSingles, b.Bytes())
|
||||
unpackedSingles = append(unpackedSingles, single)
|
||||
}
|
||||
|
||||
// With all singles packed, we'll create the grouped type and attempt
|
||||
// to Unpack all of them in a single go.
|
||||
freshSingles, err := PackedSingles(packedSingles).Unpack(keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unpack singles: %v", err)
|
||||
}
|
||||
|
||||
// The set of freshly unpacked singles should exactly match the initial
|
||||
// set of singles that we packed before.
|
||||
for i := 0; i < len(unpackedSingles); i++ {
|
||||
assertSingleEqual(t, unpackedSingles[i], freshSingles[i])
|
||||
}
|
||||
|
||||
// If we mutate one of the packed singles, then the entire method
|
||||
// should fail.
|
||||
packedSingles[0][0] ^= 1
|
||||
_, err = PackedSingles(packedSingles).Unpack(keyRing)
|
||||
if err == nil {
|
||||
t.Fatalf("unpack attempt should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSinglePackStaticChanBackups tests that we're able to batch pack a set of
|
||||
// Singles, and then unpack them obtaining the same set of unpacked singles.
|
||||
func TestSinglePackStaticChanBackups(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRing := &mockKeyRing{}
|
||||
|
||||
// First, we'll create a set of random single, and along the way,
|
||||
// create a map that will let us look up each single by its chan point.
|
||||
numSingles := 10
|
||||
singleMap := make(map[wire.OutPoint]Single, numSingles)
|
||||
unpackedSingles := make([]Single, 0, numSingles)
|
||||
for i := 0; i < numSingles; i++ {
|
||||
channel, err := genRandomOpenChannelShell()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to gen channel: %v", err)
|
||||
}
|
||||
|
||||
single := NewSingle(channel, nil)
|
||||
|
||||
singleMap[channel.FundingOutpoint] = single
|
||||
unpackedSingles = append(unpackedSingles, single)
|
||||
}
|
||||
|
||||
// Now that we have all of our singles are created, we'll attempt to
|
||||
// pack them all in a single batch.
|
||||
packedSingleMap, err := PackStaticChanBackups(unpackedSingles, keyRing)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to pack backups: %v", err)
|
||||
}
|
||||
|
||||
// With our packed singles obtained, we'll ensure that each of them
|
||||
// match their unpacked counterparts after they themselves have been
|
||||
// unpacked.
|
||||
for chanPoint, single := range singleMap {
|
||||
packedSingles, ok := packedSingleMap[chanPoint]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find single %v", chanPoint)
|
||||
}
|
||||
|
||||
var freshSingle Single
|
||||
err := freshSingle.UnpackFromReader(
|
||||
bytes.NewReader(packedSingles), keyRing,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unpack single: %v", err)
|
||||
}
|
||||
|
||||
assertSingleEqual(t, single, freshSingle)
|
||||
}
|
||||
|
||||
// If we attempt to pack again, but force the key ring to fail, then
|
||||
// the entire method should fail.
|
||||
_, err = PackStaticChanBackups(
|
||||
unpackedSingles, &mockKeyRing{true},
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatalf("pack attempt should fail")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(roasbsef): fuzz parsing
|
@ -83,6 +83,13 @@ const (
|
||||
// in order to establish a transport session with us on the Lightning
|
||||
// p2p level (BOLT-0008).
|
||||
KeyFamilyNodeKey KeyFamily = 6
|
||||
|
||||
// KeyFamilyStaticBackup is the family of keys that will be used to
|
||||
// derive keys that we use to encrypt and decrypt our set of static
|
||||
// backups. These backups may either be stored within watch towers for
|
||||
// a payment, or self stored on disk in a single file containing all
|
||||
// the static channel backups.
|
||||
KeyFamilyStaticBackup KeyFamily = 7
|
||||
)
|
||||
|
||||
// KeyLocator is a two-tuple that can be used to derive *any* key that has ever
|
||||
|
Loading…
Reference in New Issue
Block a user