diff --git a/chanbackup/backup.go b/chanbackup/backup.go new file mode 100644 index 00000000..ca3698a5 --- /dev/null +++ b/chanbackup/backup.go @@ -0,0 +1,99 @@ +package chanbackup + +import ( + "fmt" + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" +) + +// LiveChannelSource is an interface that allows us to query for the set of +// live channels. A live channel is one that is open, and has not had a +// commitment transaction broadcast. +type LiveChannelSource interface { + // FetchAllChannels returns all known live channels. + FetchAllChannels() ([]*channeldb.OpenChannel, error) + + // FetchChannel attempts to locate a live channel identified by the + // passed chanPoint. + FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) + + // AddrsForNode returns all known addresses for the target node public + // key. + AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) +} + +// assembleChanBackup attempts to assemble a static channel backup for the +// passed open channel. The backup includes all information required to restore +// the channel, as well as addressing information so we can find the peer and +// reconnect to them to initiate the protocol. +func assembleChanBackup(chanSource LiveChannelSource, + openChan *channeldb.OpenChannel) (*Single, error) { + + log.Debugf("Crafting backup for ChannelPoint(%v)", + openChan.FundingOutpoint) + + // First, we'll query the channel source to obtain all the addresses + // that are are associated with the peer for this channel. + nodeAddrs, err := chanSource.AddrsForNode(openChan.IdentityPub) + if err != nil { + return nil, err + } + + single := NewSingle(openChan, nodeAddrs) + + return &single, nil +} + +// FetchBackupForChan attempts to create a plaintext static channel backup for +// the target channel identified by its channel point. If we're unable to find +// the target channel, then an error will be returned. +func FetchBackupForChan(chanPoint wire.OutPoint, + chanSource LiveChannelSource) (*Single, error) { + + // First, we'll query the channel source to see if the channel is known + // and open within the database. + targetChan, err := chanSource.FetchChannel(chanPoint) + if err != nil { + // If we can't find the channel, then we return with an error, + // as we have nothing to backup. + return nil, fmt.Errorf("unable to find target channel") + } + + // Once we have the target channel, we can assemble the backup using + // the source to obtain any extra information that we may need. + staticChanBackup, err := assembleChanBackup(chanSource, targetChan) + if err != nil { + return nil, fmt.Errorf("unable to create chan backup: %v", err) + } + + return staticChanBackup, nil +} + +// FetchStaticChanBackups will return a plaintext static channel back up for +// all known active/open channels within the passed channel source. +func FetchStaticChanBackups(chanSource LiveChannelSource) ([]Single, error) { + // First, we'll query the backup source for information concerning all + // currently open and available channels. + openChans, err := chanSource.FetchAllChannels() + if err != nil { + return nil, err + } + + // Now that we have all the channels, we'll use the chanSource to + // obtain any auxiliary information we need to craft a backup for each + // channel. + staticChanBackups := make([]Single, len(openChans)) + for i, openChan := range openChans { + chanBackup, err := assembleChanBackup(chanSource, openChan) + if err != nil { + return nil, err + } + + staticChanBackups[i] = *chanBackup + } + + return staticChanBackups, nil +} diff --git a/chanbackup/backup_test.go b/chanbackup/backup_test.go new file mode 100644 index 00000000..ea8bfcea --- /dev/null +++ b/chanbackup/backup_test.go @@ -0,0 +1,197 @@ +package chanbackup + +import ( + "fmt" + "net" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" +) + +type mockChannelSource struct { + chans map[wire.OutPoint]*channeldb.OpenChannel + + failQuery bool + + addrs map[[33]byte][]net.Addr +} + +func newMockChannelSource() *mockChannelSource { + return &mockChannelSource{ + chans: make(map[wire.OutPoint]*channeldb.OpenChannel), + addrs: make(map[[33]byte][]net.Addr), + } +} + +func (m *mockChannelSource) FetchAllChannels() ([]*channeldb.OpenChannel, error) { + if m.failQuery { + return nil, fmt.Errorf("fail") + } + + chans := make([]*channeldb.OpenChannel, 0, len(m.chans)) + for _, channel := range m.chans { + chans = append(chans, channel) + } + + return chans, nil +} + +func (m *mockChannelSource) FetchChannel(chanPoint wire.OutPoint) (*channeldb.OpenChannel, error) { + if m.failQuery { + return nil, fmt.Errorf("fail") + } + + channel, ok := m.chans[chanPoint] + if !ok { + return nil, fmt.Errorf("can't find chan") + } + + return channel, nil +} + +func (m *mockChannelSource) addAddrsForNode(nodePub *btcec.PublicKey, addrs []net.Addr) { + var nodeKey [33]byte + copy(nodeKey[:], nodePub.SerializeCompressed()) + + m.addrs[nodeKey] = addrs +} + +func (m *mockChannelSource) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { + if m.failQuery { + return nil, fmt.Errorf("fail") + } + + var nodeKey [33]byte + copy(nodeKey[:], nodePub.SerializeCompressed()) + + addrs, ok := m.addrs[nodeKey] + if !ok { + return nil, fmt.Errorf("can't find addr") + } + + return addrs, nil +} + +// TestFetchBackupForChan tests that we're able to construct a single channel +// backup for channels that are known, unknown, and also channels in which we +// can find addresses for and otherwise. +func TestFetchBackupForChan(t *testing.T) { + t.Parallel() + + // First, we'll make two channels, only one of them will have all the + // information we need to construct set of backups for them. + randomChan1, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to generate chan: %v", err) + } + randomChan2, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to generate chan: %v", err) + } + + chanSource := newMockChannelSource() + chanSource.chans[randomChan1.FundingOutpoint] = randomChan1 + chanSource.chans[randomChan2.FundingOutpoint] = randomChan2 + + chanSource.addAddrsForNode(randomChan1.IdentityPub, []net.Addr{addr1}) + + testCases := []struct { + chanPoint wire.OutPoint + + pass bool + }{ + // Able to find channel, and addresses, should pass. + { + chanPoint: randomChan1.FundingOutpoint, + pass: true, + }, + + // Able to find channel, not able to find addrs, should fail. + { + chanPoint: randomChan2.FundingOutpoint, + pass: false, + }, + + // Not able to find channel, should fail. + { + chanPoint: op, + pass: false, + }, + } + for i, testCase := range testCases { + _, err := FetchBackupForChan(testCase.chanPoint, chanSource) + switch { + // If this is a valid test case, and we failed, then we'll + // return an error. + case err != nil && testCase.pass: + t.Fatalf("#%v, unable to make chan backup: %v", i, err) + + // If this is an invalid test case, and we passed it, then + // we'll return an error. + case err == nil && !testCase.pass: + t.Fatalf("#%v got nil error for invalid req: %v", + i, err) + } + } +} + +// TestFetchStaticChanBackups tests that we're able to properly query the +// channel source for all channels and construct a Single for each channel. +func TestFetchStaticChanBackups(t *testing.T) { + t.Parallel() + + // First, we'll make the set of channels that we want to seed the + // channel source with. Both channels will be fully populated in the + // channel source. + const numChans = 2 + randomChan1, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to generate chan: %v", err) + } + randomChan2, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to generate chan: %v", err) + } + + chanSource := newMockChannelSource() + chanSource.chans[randomChan1.FundingOutpoint] = randomChan1 + chanSource.chans[randomChan2.FundingOutpoint] = randomChan2 + chanSource.addAddrsForNode(randomChan1.IdentityPub, []net.Addr{addr1}) + chanSource.addAddrsForNode(randomChan2.IdentityPub, []net.Addr{addr2}) + + // With the channel source populated, we'll now attempt to create a set + // of backups for all the channels. This should succeed, as all items + // are populated within the channel source. + backups, err := FetchStaticChanBackups(chanSource) + if err != nil { + t.Fatalf("unable to create chan back ups: %v", err) + } + + if len(backups) != numChans { + t.Fatalf("expected %v chans, instead got %v", numChans, + len(backups)) + } + + // We'll attempt to create a set up backups again, but this time the + // second channel will have missing information, which should cause the + // query to fail. + var n [33]byte + copy(n[:], randomChan2.IdentityPub.SerializeCompressed()) + delete(chanSource.addrs, n) + + _, err = FetchStaticChanBackups(chanSource) + if err == nil { + t.Fatalf("query with incomplete information should fail") + } + + // To wrap up, we'll ensure that if we're unable to query the channel + // source at all, then we'll fail as well. + chanSource = newMockChannelSource() + chanSource.failQuery = true + _, err = FetchStaticChanBackups(chanSource) + if err == nil { + t.Fatalf("query should fail") + } +} diff --git a/chanbackup/backupfile.go b/chanbackup/backupfile.go new file mode 100644 index 00000000..0dbf9dc4 --- /dev/null +++ b/chanbackup/backupfile.go @@ -0,0 +1,160 @@ +package chanbackup + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + + "github.com/lightningnetwork/lnd/keychain" +) + +const ( + // DefaultBackupFileName is the default name of the auto updated static + // channel backup fie. + DefaultBackupFileName = "channel.backup" + + // DefaultTempBackupFileName is the default name of the temporary SCB + // file that we'll use to atomically update the primary back up file + // when new channel are detected. + DefaultTempBackupFileName = "temp-dont-use.backup" +) + +var ( + // ErrNoBackupFileExists is returned if caller attempts to call + // UpdateAndSwap with the file name not set. + ErrNoBackupFileExists = fmt.Errorf("back up file name not set") + + // ErrNoTempBackupFile is returned if caller attempts to call + // UpdateAndSwap with the temp back up file name not set. + ErrNoTempBackupFile = fmt.Errorf("temp backup file not set") +) + +// MultiFile represents a file on disk that a caller can use to read the packed +// multi backup into an unpacked one, and also atomically update the contents +// on disk once new channels have been opened, and old ones closed. This struct +// relies on an atomic file rename property which most widely use file systems +// have. +type MultiFile struct { + // fileName is the file name of the main back up file. + fileName string + + // mainFile is an open handle to the main back up file. + mainFile *os.File + + // tempFileName is the name of the file that we'll use to stage a new + // packed multi-chan backup, and the rename to the main back up file. + tempFileName string + + // tempFile is an open handle to the temp back up file. + tempFile *os.File +} + +// NewMultiFile create a new multi-file instance at the target location on the +// file system. +func NewMultiFile(fileName string) *MultiFile { + + // We'll our temporary backup file in the very same directory as the + // main backup file. + backupFileDir := filepath.Dir(fileName) + tempFileName := filepath.Join( + backupFileDir, DefaultTempBackupFileName, + ) + + return &MultiFile{ + fileName: fileName, + tempFileName: tempFileName, + } +} + +// UpdateAndSwap will attempt write a new temporary backup file to disk with +// the newBackup encoded, then atomically swap (via rename) the old file for +// the new file by updating the name of the new file to the old. +func (b *MultiFile) UpdateAndSwap(newBackup PackedMulti) error { + // If the main backup file isn't set, then we can't proceed. + if b.fileName == "" { + return ErrNoBackupFileExists + } + + // If the old back up file still exists, then we'll delete it before + // proceeding. + if _, err := os.Stat(b.tempFileName); err == nil { + log.Infof("Found old temp backup @ %v, removing before swap", + b.tempFileName) + + err = os.Remove(b.tempFileName) + if err != nil { + return fmt.Errorf("unable to remove temp "+ + "backup file: %v", err) + } + } + + // Now that we know the staging area is clear, we'll create the new + // temporary back up file. + var err error + b.tempFile, err = os.Create(b.tempFileName) + if err != nil { + return err + } + + // With the file created, we'll write the new packed multi backup and + // remove the temporary file all together once this method exits. + _, err = b.tempFile.Write([]byte(newBackup)) + if err != nil { + return err + } + if err := b.tempFile.Sync(); err != nil { + return err + } + defer os.Remove(b.tempFileName) + + log.Infof("Swapping old multi backup file from %v to %v", + b.tempFileName, b.fileName) + + // Finally, we'll attempt to atomically rename the temporary file to + // the main back up file. If this succeeds, then we'll only have a + // single file on disk once this method exits. + return os.Rename(b.tempFileName, b.fileName) +} + +// ExtractMulti attempts to extract the packed multi backup we currently point +// to into an unpacked version. This method will fail if no backup file +// currently exists as the specified location. +func (b *MultiFile) ExtractMulti(keyChain keychain.KeyRing) (*Multi, error) { + var err error + + // If the backup file isn't already set, then we'll attempt to open it + // anew. + if b.mainFile == nil { + // We'll return an error if the main file isn't currently set. + if b.fileName == "" { + return nil, ErrNoBackupFileExists + } + + // Otherwise, we'll open the file to prep for reading the + // contents. + b.mainFile, err = os.Open(b.fileName) + if err != nil { + return nil, err + } + } + + // Before we start to read the file, we'll ensure that the next read + // call will start from the front of the file. + _, err = b.mainFile.Seek(0, 0) + if err != nil { + return nil, err + } + + // With our seek successful, we'll now attempt to read the contents of + // the entire file in one swoop. + multiBytes, err := ioutil.ReadAll(b.mainFile) + if err != nil { + return nil, err + } + + // Finally, we'll attempt to unpack the file and return the unpack + // version to the caller. + packedMulti := PackedMulti(multiBytes) + return packedMulti.Unpack(keyChain) +} diff --git a/chanbackup/backupfile_test.go b/chanbackup/backupfile_test.go new file mode 100644 index 00000000..19733c36 --- /dev/null +++ b/chanbackup/backupfile_test.go @@ -0,0 +1,289 @@ +package chanbackup + +import ( + "bytes" + "fmt" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "testing" +) + +func makeFakePackedMulti() (PackedMulti, error) { + newPackedMulti := make([]byte, 50) + if _, err := rand.Read(newPackedMulti[:]); err != nil { + return nil, fmt.Errorf("unable to make test backup: %v", err) + } + + return PackedMulti(newPackedMulti), nil +} + +func assertBackupMatches(t *testing.T, filePath string, + currentBackup PackedMulti) { + + t.Helper() + + packedBackup, err := ioutil.ReadFile(filePath) + if err != nil { + t.Fatalf("unable to test file: %v", err) + } + + if !bytes.Equal(packedBackup, currentBackup) { + t.Fatalf("backups don't match after first swap: "+ + "expected %x got %x", packedBackup[:], + currentBackup) + } +} + +func assertFileDeleted(t *testing.T, filePath string) { + t.Helper() + + _, err := os.Stat(filePath) + if err == nil { + t.Fatalf("file %v still exists: ", filePath) + } +} + +// TestUpdateAndSwap test that we're able to properly swap out old backups on +// disk with new ones. Additionally, after a swap operation succeeds, then each +// time we should only have the main backup file on disk, as the temporary file +// has been removed. +func TestUpdateAndSwap(t *testing.T) { + t.Parallel() + + tempTestDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatalf("unable to make temp dir: %v", err) + } + defer os.Remove(tempTestDir) + + testCases := []struct { + fileName string + tempFileName string + + oldTempExists bool + + valid bool + }{ + // Main file name is blank, should fail. + { + fileName: "", + valid: false, + }, + + // Old temporary file still exists, should be removed. Only one + // file should remain. + { + fileName: filepath.Join( + tempTestDir, DefaultBackupFileName, + ), + tempFileName: filepath.Join( + tempTestDir, DefaultTempBackupFileName, + ), + oldTempExists: true, + valid: true, + }, + + // Old temp doesn't exist, should swap out file, only a single + // file remains. + { + fileName: filepath.Join( + tempTestDir, DefaultBackupFileName, + ), + tempFileName: filepath.Join( + tempTestDir, DefaultTempBackupFileName, + ), + valid: true, + }, + } + for i, testCase := range testCases { + // Ensure that all created files are removed at the end of the + // test case. + defer os.Remove(testCase.fileName) + defer os.Remove(testCase.tempFileName) + + backupFile := NewMultiFile(testCase.fileName) + + // To start with, we'll make a random byte slice that'll pose + // as our packed multi backup. + newPackedMulti, err := makeFakePackedMulti() + if err != nil { + t.Fatalf("unable to make test backup: %v", err) + } + + // If the old temporary file is meant to exist, then we'll + // create it now as an empty file. + if testCase.oldTempExists { + _, err := os.Create(testCase.tempFileName) + if err != nil { + t.Fatalf("unable to create temp file: %v", err) + } + + // TODO(roasbeef): mock out fs calls? + } + + // With our backup created, we'll now attempt to swap out this + // backup, for the old one. + err = backupFile.UpdateAndSwap(PackedMulti(newPackedMulti)) + switch { + // If this is a valid test case, and we failed, then we'll + // return an error. + case err != nil && testCase.valid: + t.Fatalf("#%v, unable to swap file: %v", i, err) + + // If this is an invalid test case, and we passed it, then + // we'll return an error. + case err == nil && !testCase.valid: + t.Fatalf("#%v file swap should have failed: %v", i, err) + } + + if !testCase.valid { + continue + } + + // If we read out the file on disk, then it should match + // exactly what we wrote. The temp backup file should also be + // gone. + assertBackupMatches(t, testCase.fileName, newPackedMulti) + assertFileDeleted(t, testCase.tempFileName) + + // Now that we know this is a valid test case, we'll make a new + // packed multi to swap out this current one. + newPackedMulti2, err := makeFakePackedMulti() + if err != nil { + t.Fatalf("unable to make test backup: %v", err) + } + + // We'll then attempt to swap the old version for this new one. + err = backupFile.UpdateAndSwap(PackedMulti(newPackedMulti2)) + if err != nil { + t.Fatalf("unable to swap file: %v", err) + } + + // Once again, the file written on disk should have been + // properly swapped out with the new instance. + assertBackupMatches(t, testCase.fileName, newPackedMulti2) + + // Additionally, we shouldn't be able to find the temp backup + // file on disk, as it should be deleted each time. + assertFileDeleted(t, testCase.tempFileName) + } +} + +func assertMultiEqual(t *testing.T, a, b *Multi) { + + if len(a.StaticBackups) != len(b.StaticBackups) { + t.Fatalf("expected %v backups, got %v", len(a.StaticBackups), + len(b.StaticBackups)) + } + + for i := 0; i < len(a.StaticBackups); i++ { + assertSingleEqual(t, a.StaticBackups[i], b.StaticBackups[i]) + } +} + +// TestExtractMulti tests that given a valid packed multi file on disk, we're +// able to read it multiple times repeatedly. +func TestExtractMulti(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + // First, as prep, we'll create a single chan backup, then pack that + // fully into a multi backup. + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to gen chan: %v", err) + } + + singleBackup := NewSingle(channel, nil) + + var b bytes.Buffer + unpackedMulti := Multi{ + StaticBackups: []Single{singleBackup}, + } + err = unpackedMulti.PackToWriter(&b, keyRing) + if err != nil { + t.Fatalf("unable to pack to writer: %v", err) + } + + packedMulti := PackedMulti(b.Bytes()) + + // Finally, we'll make a new temporary file, then write out the packed + // multi directly to to it. + tempFile, err := ioutil.TempFile("", "") + if err != nil { + t.Fatalf("unable to create temp file: %v", err) + } + defer os.Remove(tempFile.Name()) + + _, err = tempFile.Write(packedMulti) + if err != nil { + t.Fatalf("unable to write temp file: %v", err) + } + if err := tempFile.Sync(); err != nil { + t.Fatalf("unable to sync temp file: %v", err) + } + + testCases := []struct { + fileName string + pass bool + }{ + // Main file not read, file name not present. + { + fileName: "", + pass: false, + }, + + // Main file not read, file name is there, but file doesn't + // exist. + { + fileName: "kek", + pass: false, + }, + + // Main file not read, should be able to read multiple times. + { + fileName: tempFile.Name(), + pass: true, + }, + } + for i, testCase := range testCases { + // First, we'll make our backup file with the specified name. + backupFile := NewMultiFile(testCase.fileName) + + // With our file made, we'll now attempt to read out the + // multi-file. + freshUnpackedMulti, err := backupFile.ExtractMulti(keyRing) + switch { + // If this is a valid test case, and we failed, then we'll + // return an error. + case err != nil && testCase.pass: + t.Fatalf("#%v, unable to extract file: %v", i, err) + + // If this is an invalid test case, and we passed it, then + // we'll return an error. + case err == nil && !testCase.pass: + t.Fatalf("#%v file extraction should have "+ + "failed: %v", i, err) + } + + if !testCase.pass { + continue + } + + // We'll now ensure that the unpacked multi we read is + // identical to the one we wrote out above. + assertMultiEqual(t, &unpackedMulti, freshUnpackedMulti) + + // We should also be able to read the file again, as we have an + // existing handle to it. + freshUnpackedMulti, err = backupFile.ExtractMulti(keyRing) + if err != nil { + t.Fatalf("unable to unpack multi: %v", err) + } + + assertMultiEqual(t, &unpackedMulti, freshUnpackedMulti) + } +} diff --git a/chanbackup/crypto.go b/chanbackup/crypto.go new file mode 100644 index 00000000..8fdb46f6 --- /dev/null +++ b/chanbackup/crypto.go @@ -0,0 +1,140 @@ +package chanbackup + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "io/ioutil" + + "github.com/lightningnetwork/lnd/keychain" + "golang.org/x/crypto/chacha20poly1305" +) + +// TODO(roasbeef): interface in front of? + +// baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base +// encryption key used for encrypting all static channel backups. We use this +// to then derive the actual key that we'll use for encryption. We do this +// rather than using the raw key, as we assume that we can't obtain the raw +// keys, and we don't want to require that the HSM know our target cipher for +// encryption. +// +// TODO(roasbeef): possibly unique encrypt? +var baseEncryptionKeyLoc = keychain.KeyLocator{ + Family: keychain.KeyFamilyStaticBackup, + Index: 0, +} + +// genEncryptionKey derives the key that we'll use to encrypt all of our static +// channel backups. The key itself, is the sha2 of a base key that we get from +// the keyring. We derive the key this way as we don't force the HSM (or any +// future abstractions) to be able to derive and know of the cipher that we'll +// use within our protocol. +func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) { + // key = SHA256(baseKey) + baseKey, err := keyRing.DeriveKey( + baseEncryptionKeyLoc, + ) + if err != nil { + return nil, err + } + + encryptionKey := sha256.Sum256( + baseKey.PubKey.SerializeCompressed(), + ) + + // TODO(roasbeef): throw back in ECDH? + + return encryptionKey[:], nil +} + +// encryptPayloadToWriter attempts to write the set of bytes contained within +// the passed byes.Buffer into the passed io.Writer in an encrypted form. We +// use a 24-byte chachapoly AEAD instance with a randomized nonce that's +// pre-pended to the final payload and used as associated data in the AEAD. We +// use the passed keyRing to generate the encryption key, see genEncryptionKey +// for further details. +func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer, + keyRing keychain.KeyRing) error { + + // First, we'll derive the key that we'll use to encrypt the payload + // for safe storage without giving away the details of any of our + // channels. The final operation is: + // + // key = SHA256(baseKey) + encryptionKey, err := genEncryptionKey(keyRing) + if err != nil { + return err + } + + // Before encryption, we'll initialize our cipher with the target + // encryption key, and also read out our random 24-byte nonce we use + // for encryption. Note that we use NewX, not New, as the latter + // version requires a 12-byte nonce, not a 24-byte nonce. + cipher, err := chacha20poly1305.NewX(encryptionKey) + if err != nil { + return err + } + var nonce [chacha20poly1305.NonceSizeX]byte + if _, err := rand.Read(nonce[:]); err != nil { + return err + } + + // Finally, we encrypted the final payload, and write out our + // ciphertext with nonce pre-pended. + ciphertext := cipher.Seal(nil, nonce[:], payload.Bytes(), nonce[:]) + + if _, err := w.Write(nonce[:]); err != nil { + return err + } + if _, err := w.Write(ciphertext); err != nil { + return err + } + + return nil +} + +// decryptPayloadFromReader attempts to decrypt the encrypted bytes within the +// passed io.Reader instance using the key derived from the passed keyRing. For +// further details regarding the key derivation protocol, see the +// genEncryptionKey method. +func decryptPayloadFromReader(payload io.Reader, + keyRing keychain.KeyRing) ([]byte, error) { + + // First, we'll re-generate the encryption key that we use for all the + // SCBs. + encryptionKey, err := genEncryptionKey(keyRing) + if err != nil { + return nil, err + } + + // Next, we'll read out the entire blob as we need to isolate the nonce + // from the rest of the ciphertext. + packedBackup, err := ioutil.ReadAll(payload) + if err != nil { + return nil, err + } + if len(packedBackup) < chacha20poly1305.NonceSizeX { + return nil, fmt.Errorf("payload size too small, must be at "+ + "least %v bytes", chacha20poly1305.NonceSizeX) + } + + nonce := packedBackup[:chacha20poly1305.NonceSizeX] + ciphertext := packedBackup[chacha20poly1305.NonceSizeX:] + + // Now that we have the cipher text and the nonce separated, we can go + // ahead and decrypt the final blob so we can properly serialized the + // SCB. + cipher, err := chacha20poly1305.NewX(encryptionKey) + if err != nil { + return nil, err + } + plaintext, err := cipher.Open(nil, nonce, ciphertext, nonce) + if err != nil { + return nil, err + } + + return plaintext, nil +} diff --git a/chanbackup/crypto_test.go b/chanbackup/crypto_test.go new file mode 100644 index 00000000..6b4b27fe --- /dev/null +++ b/chanbackup/crypto_test.go @@ -0,0 +1,156 @@ +package chanbackup + +import ( + "bytes" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/keychain" +) + +var ( + testWalletPrivKey = []byte{ + 0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf, + 0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9, + 0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f, + 0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90, + } +) + +type mockKeyRing struct { + fail bool +} + +func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) { + return keychain.KeyDescriptor{}, nil +} +func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) { + if m.fail { + return keychain.KeyDescriptor{}, fmt.Errorf("fail") + } + + _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey) + return keychain.KeyDescriptor{ + PubKey: pub, + }, nil +} + +// TestEncryptDecryptPayload tests that given a static key, we're able to +// properly decrypt and encrypted payload. We also test that we'll reject a +// ciphertext that has been modified. +func TestEncryptDecryptPayload(t *testing.T) { + t.Parallel() + + payloadCases := []struct { + // plaintext is the string that we'll be encrypting. + plaintext []byte + + // mutator allows a test case to modify the ciphertext before + // we attempt to decrypt it. + mutator func(*[]byte) + + // valid indicates if this test should pass or fail. + valid bool + }{ + // Proper payload, should decrypt. + { + plaintext: []byte("payload test plain text"), + mutator: nil, + valid: true, + }, + + // Mutator modifies cipher text, shouldn't decrypt. + { + plaintext: []byte("payload test plain text"), + mutator: func(p *[]byte) { + // Flip a byte in the payload to render it invalid. + (*p)[0] ^= 1 + }, + valid: false, + }, + + // Cipher text is too small, shouldn't decrypt. + { + plaintext: []byte("payload test plain text"), + mutator: func(p *[]byte) { + // Modify the cipher text to be zero length. + *p = []byte{} + }, + valid: false, + }, + } + + keyRing := &mockKeyRing{} + + for i, payloadCase := range payloadCases { + var cipherBuffer bytes.Buffer + + // First, we'll encrypt the passed payload with our scheme. + payloadReader := bytes.NewBuffer(payloadCase.plaintext) + err := encryptPayloadToWriter( + *payloadReader, &cipherBuffer, keyRing, + ) + if err != nil { + t.Fatalf("unable encrypt paylaod: %v", err) + } + + // If we have a mutator, then we'll wrong the mutator over the + // cipher text, then reset the main buffer and re-write the new + // cipher text. + if payloadCase.mutator != nil { + cipherText := cipherBuffer.Bytes() + + payloadCase.mutator(&cipherText) + + cipherBuffer.Reset() + cipherBuffer.Write(cipherText) + } + + plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing) + + switch { + // If this was meant to be a valid decryption, but we failed, + // then we'll return an error. + case err != nil && payloadCase.valid: + t.Fatalf("unable to decrypt valid payload case %v", i) + + // If this was meant to be an invalid decryption, and we didn't + // fail, then we'll return an error. + case err == nil && !payloadCase.valid: + t.Fatalf("payload was invalid yet was able to decrypt") + } + + // Only if this case was mean to be valid will we ensure the + // resulting decrypted plaintext matches the original input. + if payloadCase.valid && + !bytes.Equal(plaintext, payloadCase.plaintext) { + t.Fatalf("#%v: expected %v, got %v: ", i, + payloadCase.plaintext, plaintext) + } + } +} + +// TestInvalidKeyEncryption tests that encryption fails if we're unable to +// obtain a valid key. +func TestInvalidKeyEncryption(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + err := encryptPayloadToWriter(b, &b, &mockKeyRing{true}) + if err == nil { + t.Fatalf("expected error due to fail key gen") + } +} + +// TestInvalidKeyDecrytion tests that decryption fails if we're unable to +// obtain a valid key. +func TestInvalidKeyDecrytion(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + _, err := decryptPayloadFromReader(&b, &mockKeyRing{true}) + if err == nil { + t.Fatalf("expected error due to fail key gen") + } +} diff --git a/chanbackup/log.go b/chanbackup/log.go new file mode 100644 index 00000000..730fbc9e --- /dev/null +++ b/chanbackup/log.go @@ -0,0 +1,45 @@ +package chanbackup + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger("CHBU", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} + +// logClosure is used to provide a closure over expensive logging operations so +// don't have to be performed when the logging level doesn't warrant it. +type logClosure func() string + +// String invokes the underlying function and returns the result. +func (c logClosure) String() string { + return c() +} + +// newLogClosure returns a new closure over a function that returns a string +// which itself provides a Stringer interface so that it can be used with the +// logging system. +func newLogClosure(c func() string) logClosure { + return logClosure(c) +} diff --git a/chanbackup/multi.go b/chanbackup/multi.go new file mode 100644 index 00000000..d77be204 --- /dev/null +++ b/chanbackup/multi.go @@ -0,0 +1,176 @@ +package chanbackup + +import ( + "bytes" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" +) + +// MultiBackupVersion denotes the version of the multi channel static channel +// backup. Based on this version, we know how to encode/decode packed/unpacked +// versions of multi backups. +type MultiBackupVersion byte + +const ( + // DefaultMultiVersion is the default version of the multi channel + // backup. The serialized format for this version is simply: version || + // numBackups || SCBs... + DefaultMultiVersion = 0 +) + +// Multi is a form of static channel backup that is amenable to being +// serialized in a single file. Rather than a series of ciphertexts, a +// multi-chan backup is a single ciphertext of all static channel backups +// concatenated. This form factor gives users a single blob that they can use +// to safely copy/obtain at anytime to backup their channels. +type Multi struct { + // Version is the version that should be observed when attempting to + // pack the multi backup. + Version MultiBackupVersion + + // StaticBackups is the set of single channel backups that this multi + // backup is comprised of. + StaticBackups []Single +} + +// PackToWriter packs (encrypts+serializes) the target set of static channel +// backups into a single AEAD ciphertext into the passed io.Writer. This is the +// opposite of UnpackFromReader. The plaintext form of a multi-chan backup is +// the following: a 4 byte integer denoting the number of serialized static +// channel backups serialized, a series of serialized static channel backups +// concatenated. To pack this payload, we then apply our chacha20 AEAD to the +// entire payload, using the 24-byte nonce as associated data. +func (m Multi) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error { + // The only version that we know how to pack atm is version 0. Attempts + // to pack any other version will result in an error. + switch m.Version { + case DefaultMultiVersion: + break + + default: + return fmt.Errorf("unable to pack unknown multi-version "+ + "of %v", m.Version) + } + + var multiBackupBuffer bytes.Buffer + + // First, we'll write out the version of this multi channel baackup. + err := lnwire.WriteElements(&multiBackupBuffer, byte(m.Version)) + if err != nil { + return err + } + + // Now that we've written out the version of this multi-pack format, + // we'll now write the total number of backups to expect after this + // point. + numBackups := uint32(len(m.StaticBackups)) + err = lnwire.WriteElements(&multiBackupBuffer, numBackups) + if err != nil { + return err + } + + // Next, we'll serialize the raw plaintext version of each of the + // backup into the intermediate buffer. + for _, chanBackup := range m.StaticBackups { + err := chanBackup.Serialize(&multiBackupBuffer) + if err != nil { + return fmt.Errorf("unable to serialize backup "+ + "for %v: %v", chanBackup.FundingOutpoint, err) + } + } + + // With the plaintext multi backup assembled, we'll now encrypt it + // directly to the passed writer. + return encryptPayloadToWriter(multiBackupBuffer, w, keyRing) +} + +// UnpackFromReader attempts to unpack (decrypt+deserialize) a packed +// multi-chan backup form the passed io.Reader. If we're unable to decrypt the +// any portion of the multi-chan backup, an error will be returned. +func (m *Multi) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { + // We'll attempt to read the entire packed backup, and also decrypt it + // using the passed key ring which is expected to be able to derive the + // encryption keys. + plaintextBackup, err := decryptPayloadFromReader(r, keyRing) + if err != nil { + return err + } + backupReader := bytes.NewReader(plaintextBackup) + + // Now that we've decrypted the payload successfully, we can parse out + // each of the individual static channel backups. + + // First, we'll need to read the version of this multi-back up so we + // can know how to unpack each of the individual SCB's. + var multiVersion byte + err = lnwire.ReadElements(backupReader, &multiVersion) + if err != nil { + return err + } + + m.Version = MultiBackupVersion(multiVersion) + switch m.Version { + + // The default version is simply a set of serialized SCB's with the + // number of total SCB's prepended to the front of the byte slice. + case DefaultMultiVersion: + // First, we'll need to read out the total number of backups + // that've been serialized into this multi-chan backup. Each + // backup is the same size, so we can continue until we've + // parsed out everything. + var numBackups uint32 + err = lnwire.ReadElements(backupReader, &numBackups) + if err != nil { + return err + } + + // We'll continue to parse out each backup until we've read all + // that was indicated from the length prefix. + for ; numBackups != 0; numBackups-- { + // Attempt to parse out the net static channel backup, + // if it's been malformed, then we'll return with an + // error + var chanBackup Single + err := chanBackup.Deserialize(backupReader) + if err != nil { + return err + } + + // Collect the next valid chan backup into the main + // multi backup slice. + m.StaticBackups = append(m.StaticBackups, chanBackup) + } + + default: + return fmt.Errorf("unable to unpack unknown multi-version "+ + "of %v", multiVersion) + } + + return nil +} + +// TODO(roasbeef): new key ring interface? +// * just returns key given params? + +// PackedMulti represents a raw fully packed (serialized+encrypted) +// multi-channel static channel backup. +type PackedMulti []byte + +// Unpack attempts to unpack (decrypt+desrialize) the target packed +// multi-channel back up. If we're unable to fully unpack this back, then an +// error will be returned. +func (p *PackedMulti) Unpack(keyRing keychain.KeyRing) (*Multi, error) { + var m Multi + + packedReader := bytes.NewReader(*p) + if err := m.UnpackFromReader(packedReader, keyRing); err != nil { + return nil, err + } + + return &m, nil +} + +// TODO(roasbsef): fuzz parsing diff --git a/chanbackup/multi_test.go b/chanbackup/multi_test.go new file mode 100644 index 00000000..a6317e09 --- /dev/null +++ b/chanbackup/multi_test.go @@ -0,0 +1,159 @@ +package chanbackup + +import ( + "bytes" + "net" + "testing" +) + +// TestMultiPackUnpack... +func TestMultiPackUnpack(t *testing.T) { + t.Parallel() + + var multi Multi + numSingles := 10 + originalSingles := make([]Single, 0, numSingles) + for i := 0; i < numSingles; i++ { + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to gen channel: %v", err) + } + + single := NewSingle(channel, []net.Addr{addr1, addr2}) + + originalSingles = append(originalSingles, single) + multi.StaticBackups = append(multi.StaticBackups, single) + } + + keyRing := &mockKeyRing{} + + versionTestCases := []struct { + // version is the pack/unpack version that we should use to + // decode/encode the final SCB. + version MultiBackupVersion + + // valid tests us if this test case should pass or not. + valid bool + }{ + // The default version, should pack/unpack with no problem. + { + version: DefaultSingleVersion, + valid: true, + }, + + // A non-default version, atm this should result in a failure. + { + version: 99, + valid: false, + }, + } + for i, versionCase := range versionTestCases { + multi.Version = versionCase.version + + var b bytes.Buffer + err := multi.PackToWriter(&b, keyRing) + switch { + // If this is a valid test case, and we failed, then we'll + // return an error. + case err != nil && versionCase.valid: + t.Fatalf("#%v, unable to pack multi: %v", i, err) + + // If this is an invalid test case, and we passed it, then + // we'll return an error. + case err == nil && !versionCase.valid: + t.Fatalf("#%v got nil error for invalid pack: %v", + i, err) + } + + // If this is a valid test case, then we'll continue to ensure + // we can unpack it, and also that if we mutate the packed + // version, then we trigger an error. + if versionCase.valid { + var unpackedMulti Multi + err = unpackedMulti.UnpackFromReader(&b, keyRing) + if err != nil { + t.Fatalf("#%v unable to unpack multi: %v", + i, err) + } + + // First, we'll ensure that the unpacked version of the + // packed multi is the same as the original set. + if len(originalSingles) != + len(unpackedMulti.StaticBackups) { + t.Fatalf("expected %v singles, got %v", + len(originalSingles), + len(unpackedMulti.StaticBackups)) + } + for i := 0; i < numSingles; i++ { + assertSingleEqual( + t, originalSingles[i], + unpackedMulti.StaticBackups[i], + ) + } + + // Next, we'll make a fake packed multi, it'll have an + // unknown version relative to what's implemented atm. + var fakePackedMulti bytes.Buffer + fakeRawMulti := bytes.NewBuffer( + bytes.Repeat([]byte{99}, 20), + ) + err := encryptPayloadToWriter( + *fakeRawMulti, &fakePackedMulti, keyRing, + ) + if err != nil { + t.Fatalf("unable to pack fake multi; %v", err) + } + + // We should reject this fake multi as it contains an + // unknown version. + err = unpackedMulti.UnpackFromReader( + &fakePackedMulti, keyRing, + ) + if err == nil { + t.Fatalf("#%v unpack with unknown version "+ + "should have failed", i) + } + } + } +} + +// TestPackedMultiUnpack tests that we're able to properly unpack a typed +// packed multi. +func TestPackedMultiUnpack(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + // First, we'll make a new unpacked multi with a random channel. + testChannel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to gen random channel: %v", err) + } + var multi Multi + multi.StaticBackups = append( + multi.StaticBackups, NewSingle(testChannel, nil), + ) + + // Now that we have our multi, we'll pack it into a new buffer. + var b bytes.Buffer + if err := multi.PackToWriter(&b, keyRing); err != nil { + t.Fatalf("unable to pack multi: %v", err) + } + + // We should be able to properly unpack this typed packed multi. + packedMulti := PackedMulti(b.Bytes()) + unpackedMulti, err := packedMulti.Unpack(keyRing) + if err != nil { + t.Fatalf("unable to unpack multi: %v", err) + } + + // Finally, the versions should match, and the unpacked singles also + // identical. + if multi.Version != unpackedMulti.Version { + t.Fatalf("version mismatch: expected %v got %v", + multi.Version, unpackedMulti.Version) + } + assertSingleEqual( + t, multi.StaticBackups[0], unpackedMulti.StaticBackups[0], + ) +} diff --git a/chanbackup/pubsub.go b/chanbackup/pubsub.go new file mode 100644 index 00000000..a4e41857 --- /dev/null +++ b/chanbackup/pubsub.go @@ -0,0 +1,247 @@ +package chanbackup + +import ( + "bytes" + "net" + "sync" + "sync/atomic" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/keychain" +) + +// Swapper is an interface that allows the chanbackup.SubSwapper to update the +// main multi backup location once it learns of new channels or that prior +// channels have been closed. +type Swapper interface { + // UpdateAndSwap attempts to atomically update the main multi back up + // file location with the new fully packed multi-channel backup. + UpdateAndSwap(newBackup PackedMulti) error +} + +// ChannelWithAddrs bundles an open channel along with all the addresses for +// the channel peer. +// +// TODO(roasbeef): use channel shell instead? +type ChannelWithAddrs struct { + *channeldb.OpenChannel + + // Addrs is the set of addresses that we can use to reach the target + // peer. + Addrs []net.Addr +} + +// ChannelEvent packages a new update of new channels since subscription, and +// channels that have been opened since prior channel event. +type ChannelEvent struct { + // ClosedChans are the set of channels that have been closed since the + // last event. + ClosedChans []wire.OutPoint + + // NewChans is the set of channels that have been opened since the last + // event. + NewChans []ChannelWithAddrs +} + +// ChannelSubscription represents an intent to be notified of any updates to +// the primary channel state. +type ChannelSubscription struct { + // ChanUpdates is a read-only channel that will be sent upon once the + // primary channel state is updated. + ChanUpdates <-chan ChannelEvent + + // Cancel is a closure that allows the caller to cancel their + // subscription and free up any resources allocated. + Cancel func() +} + +// ChannelNotifier represents a system that allows the chanbackup.SubSwapper to +// be notified of any changes to the primary channel state. +type ChannelNotifier interface { + // SubscribeChans requests a new channel subscription relative to the + // initial set of known channels. We use the knownChans as a + // synchronization point to ensure that the chanbackup.SubSwapper does + // not miss any channel open or close events in the period between when + // it's created, and when it requests the channel subscription. + SubscribeChans(map[wire.OutPoint]struct{}) (*ChannelSubscription, error) +} + +// SubSwapper subscribes to new updates to the open channel state, and then +// swaps out the on-disk channel backup state in response. This sub-system +// that will ensure that the multi chan backup file on disk will always be +// updated with the latest channel back up state. We'll receive new +// opened/closed channels from the ChannelNotifier, then use the Swapper to +// update the file state on disk with the new set of open channels. This can +// be used to implement a system that always keeps the multi-chan backup file +// on disk in a consistent state for safety purposes. +// +// TODO(roasbeef): better name lol +type SubSwapper struct { + started uint32 + stopped uint32 + + // backupState are the set of SCBs for all open channels we know of. + backupState map[wire.OutPoint]Single + + // chanEvents is an active subscription to receive new channel state + // over. + chanEvents *ChannelSubscription + + // keyRing is the main key ring that will allow us to pack the new + // multi backup. + keyRing keychain.KeyRing + + Swapper + + quit chan struct{} + wg sync.WaitGroup +} + +// NewSubSwapper creates a new instance of the SubSwapper given the starting +// set of channels, and the required interfaces to be notified of new channel +// updates, pack a multi backup, and swap the current best backup from its +// storage location. +func NewSubSwapper(startingChans []Single, chanNotifier ChannelNotifier, + keyRing keychain.KeyRing, backupSwapper Swapper) (*SubSwapper, error) { + + // First, we'll subscribe to the latest set of channel updates given + // the set of channels we already know of. + knownChans := make(map[wire.OutPoint]struct{}) + for _, chanBackup := range startingChans { + knownChans[chanBackup.FundingOutpoint] = struct{}{} + } + chanEvents, err := chanNotifier.SubscribeChans(knownChans) + if err != nil { + return nil, err + } + + // Next, we'll construct our own backup state so we can add/remove + // channels that have been opened and closed. + backupState := make(map[wire.OutPoint]Single) + for _, chanBackup := range startingChans { + backupState[chanBackup.FundingOutpoint] = chanBackup + } + + return &SubSwapper{ + backupState: backupState, + chanEvents: chanEvents, + keyRing: keyRing, + Swapper: backupSwapper, + quit: make(chan struct{}), + }, nil +} + +// Start starts the chanbackup.SubSwapper. +func (s *SubSwapper) Start() error { + if !atomic.CompareAndSwapUint32(&s.started, 0, 1) { + return nil + } + + log.Infof("Starting chanbackup.SubSwapper") + + s.wg.Add(1) + go s.backupUpdater() + + return nil +} + +// Stop signals the SubSwapper to being a graceful shutdown. +func (s *SubSwapper) Stop() error { + if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) { + return nil + } + + log.Infof("Stopping chanbackup.SubSwapper") + + close(s.quit) + s.wg.Wait() + + return nil +} + +// backupFileUpdater is the primary goroutine of the SubSwapper which is +// responsible for listening for changes to the channel, and updating the +// persistent multi backup state with a new packed multi of the latest channel +// state. +func (s *SubSwapper) backupUpdater() { + // Ensure that once we exit, we'll cancel our active channel + // subscription. + defer s.chanEvents.Cancel() + defer s.wg.Done() + + log.Debugf("SubSwapper's backupUpdater is active!") + + for { + select { + // The channel state has been modified! We'll evaluate all + // changes, and swap out the old packed multi with a new one + // with the latest channel state. + case chanUpdate := <-s.chanEvents.ChanUpdates: + oldStateSize := len(s.backupState) + + // For all new open channels, we'll create a new SCB + // given the required information. + for _, newChan := range chanUpdate.NewChans { + log.Debugf("Adding chanenl %v to backup state", + newChan.FundingOutpoint) + + s.backupState[newChan.FundingOutpoint] = NewSingle( + newChan.OpenChannel, newChan.Addrs, + ) + } + + // For all closed channels, we'll remove the prior + // backup state. + for _, closedChan := range chanUpdate.ClosedChans { + log.Debugf("Removing channel %v from backup "+ + "state", newLogClosure(func() string { + return closedChan.String() + })) + + delete(s.backupState, closedChan) + } + + newStateSize := len(s.backupState) + + // With our updated channel state obtained, we'll + // create a new multi from our series of singles. + var newMulti Multi + for _, backup := range s.backupState { + newMulti.StaticBackups = append( + newMulti.StaticBackups, backup, + ) + } + + // Now that our multi has been assembled, we'll attempt + // to pack (encrypt+encode) the new channel state to + // our target reader. + var b bytes.Buffer + err := newMulti.PackToWriter(&b, s.keyRing) + if err != nil { + log.Errorf("unable to pack multi backup: %v", + err) + continue + } + + log.Infof("Updating on-disk multi SCB backup: "+ + "num_old_chans=%v, num_new_chans=%v", + oldStateSize, newStateSize) + + // Finally, we'll swap out the old backup for this new + // one in a single atomic step. + err = s.Swapper.UpdateAndSwap( + PackedMulti(b.Bytes()), + ) + if err != nil { + log.Errorf("unable to update multi "+ + "backup: %v", err) + continue + } + + // Exit at once if a quit signal is detected. + case <-s.quit: + return + } + } +} diff --git a/chanbackup/pubsub_test.go b/chanbackup/pubsub_test.go new file mode 100644 index 00000000..f571c8b7 --- /dev/null +++ b/chanbackup/pubsub_test.go @@ -0,0 +1,234 @@ +package chanbackup + +import ( + "fmt" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" +) + +type mockSwapper struct { + fail bool + + swaps chan PackedMulti +} + +func newMockSwapper() *mockSwapper { + return &mockSwapper{ + swaps: make(chan PackedMulti), + } +} + +func (m *mockSwapper) UpdateAndSwap(newBackup PackedMulti) error { + if m.fail { + return fmt.Errorf("fail") + } + + m.swaps <- newBackup + + return nil +} + +type mockChannelNotifier struct { + fail bool + + chanEvents chan ChannelEvent +} + +func newMockChannelNotifier() *mockChannelNotifier { + return &mockChannelNotifier{ + chanEvents: make(chan ChannelEvent), + } +} + +func (m *mockChannelNotifier) SubscribeChans(chans map[wire.OutPoint]struct{}) ( + *ChannelSubscription, error) { + + if m.fail { + return nil, fmt.Errorf("fail") + } + + return &ChannelSubscription{ + ChanUpdates: m.chanEvents, + Cancel: func() { + }, + }, nil +} + +// TestNewSubSwapperSubscribeFail tests that if we're unable to obtain a +// channel subscription, then the entire sub-swapper will fail to start. +func TestNewSubSwapperSubscribeFail(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + var swapper mockSwapper + chanNotifier := mockChannelNotifier{ + fail: true, + } + + _, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper) + if err == nil { + t.Fatalf("expected fail due to lack of subscription") + } +} + +func assertExpectedBackupSwap(t *testing.T, swapper *mockSwapper, + subSwapper *SubSwapper, keyRing keychain.KeyRing, + expectedChanSet map[wire.OutPoint]Single) { + + t.Helper() + + select { + case newPackedMulti := <-swapper.swaps: + // If we unpack the new multi, then we should find all the old + // channels, and also the new channel included and any deleted + // channel omitted.. + newMulti, err := newPackedMulti.Unpack(keyRing) + if err != nil { + t.Fatalf("unable to unpack multi: %v", err) + } + + // Ensure that once unpacked, the current backup has the + // expected number of Singles. + if len(newMulti.StaticBackups) != len(expectedChanSet) { + t.Fatalf("new backup wasn't included: expected %v "+ + "backups have %v", len(expectedChanSet), + len(newMulti.StaticBackups)) + } + + // We should also find all the old and new channels in this new + // backup. + for _, backup := range newMulti.StaticBackups { + _, ok := expectedChanSet[backup.FundingOutpoint] + if !ok { + t.Fatalf("didn't find backup in original set: %v", + backup.FundingOutpoint) + } + } + + // The internal state of the sub-swapper should also be one + // larger. + if !reflect.DeepEqual(expectedChanSet, subSwapper.backupState) { + t.Fatalf("backup set doesn't match: expected %v got %v", + spew.Sdump(expectedChanSet), + spew.Sdump(subSwapper.backupState)) + } + + case <-time.After(time.Second * 5): + t.Fatalf("update swapper didn't swap out multi") + } +} + +// TestSubSwapperIdempotentStartStop tests that calling the Start/Stop methods +// multiple time is permitted. +func TestSubSwapperIdempotentStartStop(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + var ( + swapper mockSwapper + chanNotifier mockChannelNotifier + ) + + subSwapper, err := NewSubSwapper(nil, &chanNotifier, keyRing, &swapper) + if err != nil { + t.Fatalf("unable to init subSwapper: %v", err) + } + + subSwapper.Start() + subSwapper.Start() + + subSwapper.Stop() + subSwapper.Stop() +} + +// TestSubSwapperUpdater tests that the SubSwapper will properly swap out +// new/old channels within the channel set, and notify the swapper to update +// the master multi file backup. +func TestSubSwapperUpdater(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + chanNotifier := newMockChannelNotifier() + swapper := newMockSwapper() + + // First, we'll start out by creating a channels set for the initial + // set of channels known to the sub-swapper. + const numStartingChans = 3 + initialChanSet := make([]Single, 0, numStartingChans) + backupSet := make(map[wire.OutPoint]Single) + for i := 0; i < numStartingChans; i++ { + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to make test chan: %v", err) + } + + single := NewSingle(channel, nil) + + backupSet[channel.FundingOutpoint] = single + initialChanSet = append(initialChanSet, single) + } + + // With our channel set created, we'll make a fresh sub swapper + // instance to begin our test. + subSwapper, err := NewSubSwapper( + initialChanSet, chanNotifier, keyRing, swapper, + ) + if err != nil { + t.Fatalf("unable to make swapper: %v", err) + } + if err := subSwapper.Start(); err != nil { + t.Fatalf("unable to start sub swapper: %v", err) + } + defer subSwapper.Stop() + + // Now that the sub-swapper is active, we'll notify to add a brand new + // channel to the channel state. + newChannel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to create new chan: %v", err) + } + + // With the new channel created, we'll send a new update to the main + // goroutine telling it about this new channel. + select { + case chanNotifier.chanEvents <- ChannelEvent{ + NewChans: []ChannelWithAddrs{ + { + OpenChannel: newChannel, + }, + }, + }: + case <-time.After(time.Second * 5): + t.Fatalf("update swapper didn't read new channel: %v", err) + } + + backupSet[newChannel.FundingOutpoint] = NewSingle(newChannel, nil) + + // At this point, the sub-swapper should now have packed a new multi, + // and then sent it to the swapper so the back up can be updated. + assertExpectedBackupSwap(t, swapper, subSwapper, keyRing, backupSet) + + // We'll now trigger an update to remove an existing channel. + chanToDelete := initialChanSet[0].FundingOutpoint + select { + case chanNotifier.chanEvents <- ChannelEvent{ + ClosedChans: []wire.OutPoint{chanToDelete}, + }: + + case <-time.After(time.Second * 5): + t.Fatalf("update swapper didn't read new channel: %v", err) + } + + delete(backupSet, chanToDelete) + + // Verify that the new set of backups, now has one less after the + // sub-swapper switches the new set with the old. + assertExpectedBackupSwap(t, swapper, subSwapper, keyRing, backupSet) +} diff --git a/chanbackup/recover.go b/chanbackup/recover.go new file mode 100644 index 00000000..8619c880 --- /dev/null +++ b/chanbackup/recover.go @@ -0,0 +1,114 @@ +package chanbackup + +import ( + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" +) + +// ChannelRestorer is an interface that allows the Recover method to map the +// set of single channel backups into a set of "channel shells" and store these +// persistently on disk. The channel shell should contain all the information +// needed to execute the data loss recovery protocol once the channel peer is +// connected to. +type ChannelRestorer interface { + // RestoreChansFromSingles attempts to map the set of single channel + // backups to channel shells that will be stored persistently. Once + // these shells have been stored on disk, we'll be able to connect to + // the channel peer an execute the data loss recovery protocol. + RestoreChansFromSingles(...Single) error +} + +// PeerConnector is an interface that allows the Recover method to connect to +// the target node given the set of possible addresses. +type PeerConnector interface { + // ConnectPeer attempts to connect to the target node at the set of + // available addresses. Once this method returns with a non-nil error, + // the connector should attempt to persistently connect to the target + // peer in the background as a persistent attempt. + ConnectPeer(node *btcec.PublicKey, addrs []net.Addr) error +} + +// Recover attempts to recover the static channel state from a set of static +// channel backups. If successfully, the database will be populated with a +// series of "shell" channels. These "shell" channels cannot be used to operate +// the channel as normal, but instead are meant to be used to enter the data +// loss recovery phase, and recover the settled funds within +// the channel. In addition a LinkNode will be created for each new peer as +// well, in order to expose the addressing information required to locate to +// and connect to each peer in order to initiate the recovery protocol. +func Recover(backups []Single, restorer ChannelRestorer, + peerConnector PeerConnector) error { + + for _, backup := range backups { + log.Infof("Restoring ChannelPoint(%v) to disk: ", + backup.FundingOutpoint) + + err := restorer.RestoreChansFromSingles(backup) + if err != nil { + return err + } + + log.Infof("Attempting to connect to node=%x (addrs=%v) to "+ + "restore ChannelPoint(%v)", + backup.RemoteNodePub.SerializeCompressed(), + newLogClosure(func() string { + return spew.Sdump(backup.Addresses) + }), backup.FundingOutpoint) + + err = peerConnector.ConnectPeer( + backup.RemoteNodePub, backup.Addresses, + ) + if err != nil { + return err + } + + // TODO(roasbeef): to handle case where node has changed addrs, + // need to subscribe to new updates for target node pub to + // attempt to connect to other addrs + // + // * just to to fresh w/ call to node addrs and de-dup? + } + + return nil +} + +// TODO(roasbeef): more specific keychain interface? + +// UnpackAndRecoverSingles is a one-shot method, that given a set of packed +// single channel backups, will restore the channel state to a channel shell, +// and also reach out to connect to any of the known node addresses for that +// channel. It is assumes that after this method exists, if a connection we +// able to be established, then then PeerConnector will continue to attempt to +// re-establish a persistent connection in the background. +func UnpackAndRecoverSingles(singles PackedSingles, + keyChain keychain.KeyRing, restorer ChannelRestorer, + peerConnector PeerConnector) error { + + chanBackups, err := singles.Unpack(keyChain) + if err != nil { + return err + } + + return Recover(chanBackups, restorer, peerConnector) +} + +// UnpackAndRecoverMulti is a one-shot method, that given a set of packed +// multi-channel backups, will restore the channel states to channel shells, +// and also reach out to connect to any of the known node addresses for that +// channel. It is assumes that after this method exists, if a connection we +// able to be established, then then PeerConnector will continue to attempt to +// re-establish a persistent connection in the background. +func UnpackAndRecoverMulti(packedMulti PackedMulti, + keyChain keychain.KeyRing, restorer ChannelRestorer, + peerConnector PeerConnector) error { + + chanBackups, err := packedMulti.Unpack(keyChain) + if err != nil { + return err + } + + return Recover(chanBackups.StaticBackups, restorer, peerConnector) +} diff --git a/chanbackup/recover_test.go b/chanbackup/recover_test.go new file mode 100644 index 00000000..e2b9d71e --- /dev/null +++ b/chanbackup/recover_test.go @@ -0,0 +1,232 @@ +package chanbackup + +import ( + "bytes" + "fmt" + "net" + "testing" + + "github.com/btcsuite/btcd/btcec" +) + +type mockChannelRestorer struct { + fail bool + + callCount int +} + +func (m *mockChannelRestorer) RestoreChansFromSingles(...Single) error { + if m.fail { + return fmt.Errorf("fail") + } + + m.callCount++ + + return nil +} + +type mockPeerConnector struct { + fail bool + + callCount int +} + +func (m *mockPeerConnector) ConnectPeer(node *btcec.PublicKey, + addrs []net.Addr) error { + + if m.fail { + return fmt.Errorf("fail") + } + + m.callCount++ + + return nil +} + +// TestUnpackAndRecoverSingles tests that we're able to properly unpack and +// recover a set of packed singles. +func TestUnpackAndRecoverSingles(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + // First, we'll create a number of single chan backups that we'll + // shortly back to so we can begin our recovery attempt. + numSingles := 10 + backups := make([]Single, 0, numSingles) + var packedBackups PackedSingles + for i := 0; i < numSingles; i++ { + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable make channel: %v", err) + } + + single := NewSingle(channel, nil) + + var b bytes.Buffer + if err := single.PackToWriter(&b, keyRing); err != nil { + t.Fatalf("unable to pack single: %v", err) + } + + backups = append(backups, single) + packedBackups = append(packedBackups, b.Bytes()) + } + + chanRestorer := mockChannelRestorer{} + peerConnector := mockPeerConnector{} + + // Now that we have our backups (packed and unpacked), we'll attempt to + // restore them all in a single batch. + + // If we make the channel restore fail, then the entire method should + // as well + chanRestorer.fail = true + err := UnpackAndRecoverSingles( + packedBackups, keyRing, &chanRestorer, &peerConnector, + ) + if err == nil { + t.Fatalf("restoration should have failed") + } + + chanRestorer.fail = false + + // If we make the peer connector fail, then the entire method should as + // well + peerConnector.fail = true + err = UnpackAndRecoverSingles( + packedBackups, keyRing, &chanRestorer, &peerConnector, + ) + if err == nil { + t.Fatalf("restoration should have failed") + } + + chanRestorer.callCount-- + peerConnector.fail = false + + // Next, we'll ensure that if all the interfaces function as expected, + // then the channels will properly be unpacked and restored. + err = UnpackAndRecoverSingles( + packedBackups, keyRing, &chanRestorer, &peerConnector, + ) + if err != nil { + t.Fatalf("unable to recover chans: %v", err) + } + + // Both the restorer, and connector should have been called 10 times, + // once for each backup. + if chanRestorer.callCount != numSingles { + t.Fatalf("expected %v calls, instead got %v", + numSingles, chanRestorer.callCount) + } + if peerConnector.callCount != numSingles { + t.Fatalf("expected %v calls, instead got %v", + numSingles, peerConnector.callCount) + } + + // If we modify the keyRing, then unpacking should fail. + keyRing.fail = true + err = UnpackAndRecoverSingles( + packedBackups, keyRing, &chanRestorer, &peerConnector, + ) + if err == nil { + t.Fatalf("unpacking should have failed") + } + + // TODO(roasbeef): verify proper call args +} + +// TestUnpackAndRecoverMulti tests that we're able to properly unpack and +// recover a packed multi. +func TestUnpackAndRecoverMulti(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + // First, we'll create a number of single chan backups that we'll + // shortly back to so we can begin our recovery attempt. + numSingles := 10 + backups := make([]Single, 0, numSingles) + for i := 0; i < numSingles; i++ { + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable make channel: %v", err) + } + + single := NewSingle(channel, nil) + + backups = append(backups, single) + } + + multi := Multi{ + StaticBackups: backups, + } + + var b bytes.Buffer + if err := multi.PackToWriter(&b, keyRing); err != nil { + t.Fatalf("unable to pack multi: %v", err) + } + + // Next, we'll pack the set of singles into a packed multi, and also + // create the set of interfaces we need to carry out the remainder of + // the test. + packedMulti := PackedMulti(b.Bytes()) + + chanRestorer := mockChannelRestorer{} + peerConnector := mockPeerConnector{} + + // If we make the channel restore fail, then the entire method should + // as well + chanRestorer.fail = true + err := UnpackAndRecoverMulti( + packedMulti, keyRing, &chanRestorer, &peerConnector, + ) + if err == nil { + t.Fatalf("restoration should have failed") + } + + chanRestorer.fail = false + + // If we make the peer connector fail, then the entire method should as + // well + peerConnector.fail = true + err = UnpackAndRecoverMulti( + packedMulti, keyRing, &chanRestorer, &peerConnector, + ) + if err == nil { + t.Fatalf("restoration should have failed") + } + + chanRestorer.callCount-- + peerConnector.fail = false + + // Next, we'll ensure that if all the interfaces function as expected, + // then the channels will properly be unpacked and restored. + err = UnpackAndRecoverMulti( + packedMulti, keyRing, &chanRestorer, &peerConnector, + ) + if err != nil { + t.Fatalf("unable to recover chans: %v", err) + } + + // Both the restorer, and connector should have been called 10 times, + // once for each backup. + if chanRestorer.callCount != numSingles { + t.Fatalf("expected %v calls, instead got %v", + numSingles, chanRestorer.callCount) + } + if peerConnector.callCount != numSingles { + t.Fatalf("expected %v calls, instead got %v", + numSingles, peerConnector.callCount) + } + + // If we modify the keyRing, then unpacking should fail. + keyRing.fail = true + err = UnpackAndRecoverMulti( + packedMulti, keyRing, &chanRestorer, &peerConnector, + ) + if err == nil { + t.Fatalf("unpacking should have failed") + } + + // TODO(roasbeef): verify proper call args +} diff --git a/chanbackup/single.go b/chanbackup/single.go new file mode 100644 index 00000000..5cef73b8 --- /dev/null +++ b/chanbackup/single.go @@ -0,0 +1,346 @@ +package chanbackup + +import ( + "bytes" + "fmt" + "io" + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" +) + +// SingleBackupVersion denotes the version of the single static channel backup. +// Based on this version, we know how to pack/unpack serialized versions of the +// backup. +type SingleBackupVersion byte + +const ( + // DefaultSingleVersion is the defautl version of the single channel + // backup. The seralized version of this static channel backup is + // simply: version || SCB. Where SCB is the known format of the + // version. + DefaultSingleVersion = 0 +) + +// Single is a static description of an existing channel that can be used for +// the purposes of backing up. The fields in this struct allow a node to +// recover the settled funds within a channel in the case of partial or +// complete data loss. We provide the network address that we last used to +// connect to the peer as well, in case the node stops advertising the IP on +// the network for whatever reason. +// +// TODO(roasbeef): suffix version into struct? +type Single struct { + // Version is the version that should be observed when attempting to + // pack the single backup. + Version SingleBackupVersion + + // ChainHash is a hash which represents the blockchain that this + // channel will be opened within. This value is typically the genesis + // hash. In the case that the original chain went through a contentious + // hard-fork, then this value will be tweaked using the unique fork + // point on each branch. + ChainHash chainhash.Hash + + // FundingOutpoint is the outpoint of the final funding transaction. + // This value uniquely and globally identities the channel within the + // target blockchain as specified by the chain hash parameter. + FundingOutpoint wire.OutPoint + + // ShortChannelID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChannelID lnwire.ShortChannelID + + // RemoteNodePub is the identity public key of the remote node this + // channel has been established with. + RemoteNodePub *btcec.PublicKey + + // Addresses is a list of IP address in which either we were able to + // reach the node over in the past, OR we received an incoming + // authenticated connection for the stored identity public key. + Addresses []net.Addr + + // CsvDelay is the local CSV delay used within the channel. We may need + // this value to reconstruct our script to recover the funds on-chain + // after a force close. + CsvDelay uint16 + + // PaymentBasePoint describes how to derive base public that's used to + // deriving the key used within the non-delayed pay-to-self output on + // the commitment transaction for a node. With this information, we can + // re-derive the private key needed to sweep the funds on-chain. + PaymentBasePoint keychain.KeyLocator + + // ShaChainRootDesc describes how to derive the private key that was + // used as the shachain root for this channel. + ShaChainRootDesc keychain.KeyDescriptor +} + +// NewSingle creates a new static channel backup based on an existing open +// channel. We also pass in the set of addresses that we used in the past to +// connect to the channel peer. +func NewSingle(channel *channeldb.OpenChannel, + nodeAddrs []net.Addr) Single { + + chanCfg := channel.LocalChanCfg + + // TODO(roasbeef): update after we start to store the KeyLoc for + // shachain root + + // We'll need to obtain the shachain root which is derived directly + // from a private key in our keychain. + var b bytes.Buffer + channel.RevocationProducer.Encode(&b) // Can't return an error. + + // Once we have the root, we'll make a public key from it, such that + // the backups plaintext don't carry any private information. When we + // go to recover, we'll present this in order to derive the private + // key. + _, shaChainPoint := btcec.PrivKeyFromBytes(btcec.S256(), b.Bytes()) + + return Single{ + ChainHash: channel.ChainHash, + FundingOutpoint: channel.FundingOutpoint, + ShortChannelID: channel.ShortChannelID, + RemoteNodePub: channel.IdentityPub, + Addresses: nodeAddrs, + CsvDelay: chanCfg.CsvDelay, + PaymentBasePoint: chanCfg.PaymentBasePoint.KeyLocator, + ShaChainRootDesc: keychain.KeyDescriptor{ + PubKey: shaChainPoint, + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyRevocationRoot, + }, + }, + } +} + +// Serialize attempts to write out the serialized version of the target +// StaticChannelBackup into the passed io.Writer. +func (s *Single) Serialize(w io.Writer) error { + // Check to ensure that we'll only attempt to serialize a version that + // we're aware of. + switch s.Version { + case DefaultSingleVersion: + default: + return fmt.Errorf("unable to serialize w/ unknown "+ + "version: %v", s.Version) + } + + // If the sha chain root has specified a public key (which is + // optional), then we'll encode it now. + var shaChainPub [33]byte + if s.ShaChainRootDesc.PubKey != nil { + copy( + shaChainPub[:], + s.ShaChainRootDesc.PubKey.SerializeCompressed(), + ) + } + + // First we gather the SCB as is into a temporary buffer so we can + // determine the total length. Before we write out the serialized SCB, + // we write the length which allows us to skip any Singles that we + // don't know of when decoding a multi. + var singleBytes bytes.Buffer + if err := lnwire.WriteElements( + &singleBytes, + s.ChainHash[:], + s.FundingOutpoint, + s.ShortChannelID, + s.RemoteNodePub, + s.Addresses, + s.CsvDelay, + uint32(s.PaymentBasePoint.Family), + s.PaymentBasePoint.Index, + shaChainPub[:], + uint32(s.ShaChainRootDesc.KeyLocator.Family), + s.ShaChainRootDesc.KeyLocator.Index, + ); err != nil { + return err + } + + return lnwire.WriteElements( + w, + byte(s.Version), + uint16(len(singleBytes.Bytes())), + singleBytes.Bytes(), + ) +} + +// PackToWriter is similar to the Serialize method, but takes the operation a +// step further by encryption the raw bytes of the static channel back up. For +// encryption we use the chacah20poly1305 AEAD cipher with a 24 byte nonce and +// 32-byte key size. We use a 24-byte nonce, as we can't ensure that we have a +// global counter to use as a sequence number for nonces, and want to ensure +// that we're able to decrypt these blobs without any additional context. We +// derive the key that we use for encryption via a SHA2 operation of the with +// the golden keychain.KeyFamilyStaticBackup base encryption key. We then take +// the serialized resulting shared secret point, and hash it using sha256 to +// obtain the key that we'll use for encryption. When using the AEAD, we pass +// the nonce as associated data such that we'll be able to package the two +// together for storage. Before writing out the encrypted payload, we prepend +// the nonce to the final blob. +func (s *Single) PackToWriter(w io.Writer, keyRing keychain.KeyRing) error { + // First, we'll serialize the SCB (StaticChannelBackup) into a + // temporary buffer so we can store it in a temporary place before we + // go to encrypt the entire thing. + var rawBytes bytes.Buffer + if err := s.Serialize(&rawBytes); err != nil { + return err + } + + // Finally, we'll encrypt the raw serialized SCB (using the nonce as + // associated data), and write out the ciphertext prepend with the + // nonce that we used to the passed io.Reader. + return encryptPayloadToWriter(rawBytes, w, keyRing) +} + +// Deserialize attempts to read the raw plaintext serialized SCB from the +// passed io.Reader. If the method is successful, then the target +// StaticChannelBackup will be fully populated. +func (s *Single) Deserialize(r io.Reader) error { + // First, we'll need to read the version of this single-back up so we + // can know how to unpack each of the SCB. + var version byte + err := lnwire.ReadElements(r, &version) + if err != nil { + return err + } + + s.Version = SingleBackupVersion(version) + + switch s.Version { + case DefaultSingleVersion: + default: + return fmt.Errorf("unable to de-serialize w/ unknown "+ + "version: %v", s.Version) + } + + var length uint16 + if err := lnwire.ReadElements(r, &length); err != nil { + return err + } + + err = lnwire.ReadElements( + r, s.ChainHash[:], &s.FundingOutpoint, &s.ShortChannelID, + &s.RemoteNodePub, &s.Addresses, &s.CsvDelay, + ) + if err != nil { + return err + } + + var keyFam uint32 + if err := lnwire.ReadElements(r, &keyFam); err != nil { + return err + } + s.PaymentBasePoint.Family = keychain.KeyFamily(keyFam) + + err = lnwire.ReadElements(r, &s.PaymentBasePoint.Index) + if err != nil { + return err + } + + // Finally, we'll parse out the ShaChainRootDesc. + var ( + shaChainPub [33]byte + zeroPub [33]byte + ) + if err := lnwire.ReadElements(r, shaChainPub[:]); err != nil { + return err + } + + // Since this field is optional, we'll check to see if the pubkey has + // ben specified or not. + if !bytes.Equal(shaChainPub[:], zeroPub[:]) { + s.ShaChainRootDesc.PubKey, err = btcec.ParsePubKey( + shaChainPub[:], btcec.S256(), + ) + if err != nil { + return err + } + } + + var shaKeyFam uint32 + if err := lnwire.ReadElements(r, &shaKeyFam); err != nil { + return err + } + s.ShaChainRootDesc.KeyLocator.Family = keychain.KeyFamily(shaKeyFam) + + return lnwire.ReadElements(r, &s.ShaChainRootDesc.KeyLocator.Index) +} + +// UnpackFromReader is similar to Deserialize method, but it expects the passed +// io.Reader to contain an encrypt SCB. Refer to the SerializeAndEncrypt method +// for details w.r.t the encryption scheme used. If we're unable to decrypt the +// payload for whatever reason (wrong key, wrong nonce, etc), then this method +// will return an error. +func (s *Single) UnpackFromReader(r io.Reader, keyRing keychain.KeyRing) error { + plaintext, err := decryptPayloadFromReader(r, keyRing) + if err != nil { + return err + } + + // Finally, we'll pack the bytes into a reader to we can deserialize + // the plaintext bytes of the SCB. + backupReader := bytes.NewReader(plaintext) + return s.Deserialize(backupReader) +} + +// PackStaticChanBackups accepts a set of existing open channels, and a +// keychain.KeyRing, and returns a map of outpoints to the serialized+encrypted +// static channel backups. The passed keyRing should be backed by the users +// root HD seed in order to ensure full determinism. +func PackStaticChanBackups(backups []Single, + keyRing keychain.KeyRing) (map[wire.OutPoint][]byte, error) { + + packedBackups := make(map[wire.OutPoint][]byte) + for _, chanBackup := range backups { + chanPoint := chanBackup.FundingOutpoint + + var b bytes.Buffer + err := chanBackup.PackToWriter(&b, keyRing) + if err != nil { + return nil, fmt.Errorf("unable to pack chan backup "+ + "for %v: %v", chanPoint, err) + } + + packedBackups[chanPoint] = b.Bytes() + } + + return packedBackups, nil +} + +// PackedSingles represents a series of fully packed SCBs. This may be the +// combination of a series of individual SCBs in order to batch their +// unpacking. +type PackedSingles [][]byte + +// Unpack attempts to decrypt the passed set of encrypted SCBs and deserialize +// each one into a new SCB struct. The passed keyRing should be backed by the +// same HD seed as was used to encrypt the set of backups in the first place. +// If we're unable to decrypt any of the back ups, then we'll return an error. +func (p PackedSingles) Unpack(keyRing keychain.KeyRing) ([]Single, error) { + + backups := make([]Single, len(p)) + for i, encryptedBackup := range p { + var backup Single + + backupReader := bytes.NewReader(encryptedBackup) + err := backup.UnpackFromReader(backupReader, keyRing) + if err != nil { + return nil, err + } + + backups[i] = backup + } + + return backups, nil +} + +// TODO(roasbeef): make codec package? diff --git a/chanbackup/single_test.go b/chanbackup/single_test.go new file mode 100644 index 00000000..b63a2226 --- /dev/null +++ b/chanbackup/single_test.go @@ -0,0 +1,342 @@ +package chanbackup + +import ( + "bytes" + "math" + "math/rand" + "net" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + chainHash = chainhash.Hash{ + 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, + 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, + 0x4f, 0x2f, 0x6f, 0x25, 0x18, 0xa3, 0xef, 0xb9, + 0x64, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, + } + + op = wire.OutPoint{ + Hash: chainHash, + Index: 4, + } + + addr1, _ = net.ResolveTCPAddr("tcp", "10.0.0.2:9000") + addr2, _ = net.ResolveTCPAddr("tcp", "10.0.0.3:9000") +) + +func assertSingleEqual(t *testing.T, a, b Single) { + t.Helper() + + if a.Version != b.Version { + t.Fatalf("versions don't match: %v vs %v", a.Version, + b.Version) + } + if a.ChainHash != b.ChainHash { + t.Fatalf("chainhash doesn't match: %v vs %v", a.ChainHash, + b.ChainHash) + } + if a.FundingOutpoint != b.FundingOutpoint { + t.Fatalf("chan point doesn't match: %v vs %v", + a.FundingOutpoint, b.FundingOutpoint) + } + if a.ShortChannelID != b.ShortChannelID { + t.Fatalf("chan id doesn't match: %v vs %v", + a.ShortChannelID, b.ShortChannelID) + } + if !a.RemoteNodePub.IsEqual(b.RemoteNodePub) { + t.Fatalf("node pubs don't match %x vs %x", + a.RemoteNodePub.SerializeCompressed(), + b.RemoteNodePub.SerializeCompressed()) + } + if a.CsvDelay != b.CsvDelay { + t.Fatalf("csv delay doesn't match: %v vs %v", a.CsvDelay, + b.CsvDelay) + } + if !reflect.DeepEqual(a.PaymentBasePoint, b.PaymentBasePoint) { + t.Fatalf("base point doesn't match: %v vs %v", + spew.Sdump(a.PaymentBasePoint), + spew.Sdump(b.PaymentBasePoint)) + } + if !reflect.DeepEqual(a.ShaChainRootDesc, b.ShaChainRootDesc) { + t.Fatalf("sha chain point doesn't match: %v vs %v", + spew.Sdump(a.PaymentBasePoint), + spew.Sdump(b.PaymentBasePoint)) + } + + if len(a.Addresses) != len(b.Addresses) { + t.Fatalf("expected %v addrs got %v", len(a.Addresses), + len(b.Addresses)) + } + for i := 0; i < len(a.Addresses); i++ { + if a.Addresses[i].String() != b.Addresses[i].String() { + t.Fatalf("addr mismatch: %v vs %v", + a.Addresses[i], b.Addresses[i]) + } + } +} + +func genRandomOpenChannelShell() (*channeldb.OpenChannel, error) { + var testPriv [32]byte + if _, err := rand.Read(testPriv[:]); err != nil { + return nil, err + } + + _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:]) + + var chanPoint wire.OutPoint + if _, err := rand.Read(chanPoint.Hash[:]); err != nil { + return nil, err + } + + pub.Curve = nil + + chanPoint.Index = uint32(rand.Intn(math.MaxUint16)) + + var shaChainRoot [32]byte + if _, err := rand.Read(shaChainRoot[:]); err != nil { + return nil, err + } + + shaChainProducer := shachain.NewRevocationProducer(shaChainRoot) + + return &channeldb.OpenChannel{ + ChainHash: chainHash, + FundingOutpoint: chanPoint, + ShortChannelID: lnwire.NewShortChanIDFromInt( + uint64(rand.Int63()), + ), + IdentityPub: pub, + LocalChanCfg: channeldb.ChannelConfig{ + ChannelConstraints: channeldb.ChannelConstraints{ + CsvDelay: uint16(rand.Int63()), + }, + PaymentBasePoint: keychain.KeyDescriptor{ + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamily(rand.Int63()), + Index: uint32(rand.Int63()), + }, + }, + }, + RevocationProducer: shaChainProducer, + }, nil +} + +// TestSinglePackUnpack tests that we're able to unpack a previously packed +// channel backup. +func TestSinglePackUnpack(t *testing.T) { + t.Parallel() + + // Given our test pub key, we'll create an open channel shell that + // contains all the information we need to create a static channel + // backup. + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to gen open channel: %v", err) + } + + singleChanBackup := NewSingle(channel, []net.Addr{addr1, addr2}) + singleChanBackup.RemoteNodePub.Curve = nil + + keyRing := &mockKeyRing{} + + versionTestCases := []struct { + // version is the pack/unpack version that we should use to + // decode/encode the final SCB. + version SingleBackupVersion + + // valid tests us if this test case should pass or not. + valid bool + }{ + // The default version, should pack/unpack with no problem. + { + version: DefaultSingleVersion, + valid: true, + }, + + // A non-default version, atm this should result in a failure. + { + version: 99, + valid: false, + }, + } + for i, versionCase := range versionTestCases { + // First, we'll re-assign SCB version to what was indicated in + // the test case. + singleChanBackup.Version = versionCase.version + + var b bytes.Buffer + + err := singleChanBackup.PackToWriter(&b, keyRing) + switch { + // If this is a valid test case, and we failed, then we'll + // return an error. + case err != nil && versionCase.valid: + t.Fatalf("#%v, unable to pack single: %v", i, err) + + // If this is an invalid test case, and we passed it, then + // we'll return an error. + case err == nil && !versionCase.valid: + t.Fatalf("#%v got nil error for invalid pack: %v", + i, err) + } + + // If this is a valid test case, then we'll continue to ensure + // we can unpack it, and also that if we mutate the packed + // version, then we trigger an error. + if versionCase.valid { + var unpackedSingle Single + err = unpackedSingle.UnpackFromReader(&b, keyRing) + if err != nil { + t.Fatalf("#%v unable to unpack single: %v", + i, err) + } + unpackedSingle.RemoteNodePub.Curve = nil + + assertSingleEqual(t, singleChanBackup, unpackedSingle) + + // If this was a valid packing attempt, then we'll test + // to ensure that if we mutate the version prepended to + // the serialization, then unpacking will fail as well. + var rawSingle bytes.Buffer + err := unpackedSingle.Serialize(&rawSingle) + if err != nil { + t.Fatalf("unable to serialize single: %v", err) + } + + rawBytes := rawSingle.Bytes() + rawBytes[0] ^= 1 + + newReader := bytes.NewReader(rawBytes) + err = unpackedSingle.Deserialize(newReader) + if err == nil { + t.Fatalf("#%v unpack with unknown version "+ + "should have failed", i) + } + } + } +} + +// TestPackedSinglesUnpack tests that we're able to properly unpack a series of +// packed singles. +func TestPackedSinglesUnpack(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + // To start, we'll create 10 new singles, and them assemble their + // packed forms into a slice. + numSingles := 10 + packedSingles := make([][]byte, 0, numSingles) + unpackedSingles := make([]Single, 0, numSingles) + for i := 0; i < numSingles; i++ { + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to gen channel: %v", err) + } + + single := NewSingle(channel, nil) + + var b bytes.Buffer + if err := single.PackToWriter(&b, keyRing); err != nil { + t.Fatalf("unable to pack single: %v", err) + } + + packedSingles = append(packedSingles, b.Bytes()) + unpackedSingles = append(unpackedSingles, single) + } + + // With all singles packed, we'll create the grouped type and attempt + // to Unpack all of them in a single go. + freshSingles, err := PackedSingles(packedSingles).Unpack(keyRing) + if err != nil { + t.Fatalf("unable to unpack singles: %v", err) + } + + // The set of freshly unpacked singles should exactly match the initial + // set of singles that we packed before. + for i := 0; i < len(unpackedSingles); i++ { + assertSingleEqual(t, unpackedSingles[i], freshSingles[i]) + } + + // If we mutate one of the packed singles, then the entire method + // should fail. + packedSingles[0][0] ^= 1 + _, err = PackedSingles(packedSingles).Unpack(keyRing) + if err == nil { + t.Fatalf("unpack attempt should fail") + } +} + +// TestSinglePackStaticChanBackups tests that we're able to batch pack a set of +// Singles, and then unpack them obtaining the same set of unpacked singles. +func TestSinglePackStaticChanBackups(t *testing.T) { + t.Parallel() + + keyRing := &mockKeyRing{} + + // First, we'll create a set of random single, and along the way, + // create a map that will let us look up each single by its chan point. + numSingles := 10 + singleMap := make(map[wire.OutPoint]Single, numSingles) + unpackedSingles := make([]Single, 0, numSingles) + for i := 0; i < numSingles; i++ { + channel, err := genRandomOpenChannelShell() + if err != nil { + t.Fatalf("unable to gen channel: %v", err) + } + + single := NewSingle(channel, nil) + + singleMap[channel.FundingOutpoint] = single + unpackedSingles = append(unpackedSingles, single) + } + + // Now that we have all of our singles are created, we'll attempt to + // pack them all in a single batch. + packedSingleMap, err := PackStaticChanBackups(unpackedSingles, keyRing) + if err != nil { + t.Fatalf("unable to pack backups: %v", err) + } + + // With our packed singles obtained, we'll ensure that each of them + // match their unpacked counterparts after they themselves have been + // unpacked. + for chanPoint, single := range singleMap { + packedSingles, ok := packedSingleMap[chanPoint] + if !ok { + t.Fatalf("unable to find single %v", chanPoint) + } + + var freshSingle Single + err := freshSingle.UnpackFromReader( + bytes.NewReader(packedSingles), keyRing, + ) + if err != nil { + t.Fatalf("unable to unpack single: %v", err) + } + + assertSingleEqual(t, single, freshSingle) + } + + // If we attempt to pack again, but force the key ring to fail, then + // the entire method should fail. + _, err = PackStaticChanBackups( + unpackedSingles, &mockKeyRing{true}, + ) + if err == nil { + t.Fatalf("pack attempt should fail") + } +} + +// TODO(roasbsef): fuzz parsing diff --git a/keychain/derivation.go b/keychain/derivation.go index f63587d7..d908da75 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -83,6 +83,13 @@ const ( // in order to establish a transport session with us on the Lightning // p2p level (BOLT-0008). KeyFamilyNodeKey KeyFamily = 6 + + // KeyFamilyStaticBackup is the family of keys that will be used to + // derive keys that we use to encrypt and decrypt our set of static + // backups. These backups may either be stored within watch towers for + // a payment, or self stored on disk in a single file containing all + // the static channel backups. + KeyFamilyStaticBackup KeyFamily = 7 ) // KeyLocator is a two-tuple that can be used to derive *any* key that has ever