diff --git a/chanbackup/crypto.go b/chanbackup/crypto.go new file mode 100644 index 00000000..8fdb46f6 --- /dev/null +++ b/chanbackup/crypto.go @@ -0,0 +1,140 @@ +package chanbackup + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "io/ioutil" + + "github.com/lightningnetwork/lnd/keychain" + "golang.org/x/crypto/chacha20poly1305" +) + +// TODO(roasbeef): interface in front of? + +// baseEncryptionKeyLoc is the KeyLocator that we'll use to derive the base +// encryption key used for encrypting all static channel backups. We use this +// to then derive the actual key that we'll use for encryption. We do this +// rather than using the raw key, as we assume that we can't obtain the raw +// keys, and we don't want to require that the HSM know our target cipher for +// encryption. +// +// TODO(roasbeef): possibly unique encrypt? +var baseEncryptionKeyLoc = keychain.KeyLocator{ + Family: keychain.KeyFamilyStaticBackup, + Index: 0, +} + +// genEncryptionKey derives the key that we'll use to encrypt all of our static +// channel backups. The key itself, is the sha2 of a base key that we get from +// the keyring. We derive the key this way as we don't force the HSM (or any +// future abstractions) to be able to derive and know of the cipher that we'll +// use within our protocol. +func genEncryptionKey(keyRing keychain.KeyRing) ([]byte, error) { + // key = SHA256(baseKey) + baseKey, err := keyRing.DeriveKey( + baseEncryptionKeyLoc, + ) + if err != nil { + return nil, err + } + + encryptionKey := sha256.Sum256( + baseKey.PubKey.SerializeCompressed(), + ) + + // TODO(roasbeef): throw back in ECDH? + + return encryptionKey[:], nil +} + +// encryptPayloadToWriter attempts to write the set of bytes contained within +// the passed byes.Buffer into the passed io.Writer in an encrypted form. We +// use a 24-byte chachapoly AEAD instance with a randomized nonce that's +// pre-pended to the final payload and used as associated data in the AEAD. We +// use the passed keyRing to generate the encryption key, see genEncryptionKey +// for further details. +func encryptPayloadToWriter(payload bytes.Buffer, w io.Writer, + keyRing keychain.KeyRing) error { + + // First, we'll derive the key that we'll use to encrypt the payload + // for safe storage without giving away the details of any of our + // channels. The final operation is: + // + // key = SHA256(baseKey) + encryptionKey, err := genEncryptionKey(keyRing) + if err != nil { + return err + } + + // Before encryption, we'll initialize our cipher with the target + // encryption key, and also read out our random 24-byte nonce we use + // for encryption. Note that we use NewX, not New, as the latter + // version requires a 12-byte nonce, not a 24-byte nonce. + cipher, err := chacha20poly1305.NewX(encryptionKey) + if err != nil { + return err + } + var nonce [chacha20poly1305.NonceSizeX]byte + if _, err := rand.Read(nonce[:]); err != nil { + return err + } + + // Finally, we encrypted the final payload, and write out our + // ciphertext with nonce pre-pended. + ciphertext := cipher.Seal(nil, nonce[:], payload.Bytes(), nonce[:]) + + if _, err := w.Write(nonce[:]); err != nil { + return err + } + if _, err := w.Write(ciphertext); err != nil { + return err + } + + return nil +} + +// decryptPayloadFromReader attempts to decrypt the encrypted bytes within the +// passed io.Reader instance using the key derived from the passed keyRing. For +// further details regarding the key derivation protocol, see the +// genEncryptionKey method. +func decryptPayloadFromReader(payload io.Reader, + keyRing keychain.KeyRing) ([]byte, error) { + + // First, we'll re-generate the encryption key that we use for all the + // SCBs. + encryptionKey, err := genEncryptionKey(keyRing) + if err != nil { + return nil, err + } + + // Next, we'll read out the entire blob as we need to isolate the nonce + // from the rest of the ciphertext. + packedBackup, err := ioutil.ReadAll(payload) + if err != nil { + return nil, err + } + if len(packedBackup) < chacha20poly1305.NonceSizeX { + return nil, fmt.Errorf("payload size too small, must be at "+ + "least %v bytes", chacha20poly1305.NonceSizeX) + } + + nonce := packedBackup[:chacha20poly1305.NonceSizeX] + ciphertext := packedBackup[chacha20poly1305.NonceSizeX:] + + // Now that we have the cipher text and the nonce separated, we can go + // ahead and decrypt the final blob so we can properly serialized the + // SCB. + cipher, err := chacha20poly1305.NewX(encryptionKey) + if err != nil { + return nil, err + } + plaintext, err := cipher.Open(nil, nonce, ciphertext, nonce) + if err != nil { + return nil, err + } + + return plaintext, nil +} diff --git a/chanbackup/crypto_test.go b/chanbackup/crypto_test.go new file mode 100644 index 00000000..6b4b27fe --- /dev/null +++ b/chanbackup/crypto_test.go @@ -0,0 +1,156 @@ +package chanbackup + +import ( + "bytes" + "fmt" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/keychain" +) + +var ( + testWalletPrivKey = []byte{ + 0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf, + 0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9, + 0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f, + 0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90, + } +) + +type mockKeyRing struct { + fail bool +} + +func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) { + return keychain.KeyDescriptor{}, nil +} +func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) { + if m.fail { + return keychain.KeyDescriptor{}, fmt.Errorf("fail") + } + + _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey) + return keychain.KeyDescriptor{ + PubKey: pub, + }, nil +} + +// TestEncryptDecryptPayload tests that given a static key, we're able to +// properly decrypt and encrypted payload. We also test that we'll reject a +// ciphertext that has been modified. +func TestEncryptDecryptPayload(t *testing.T) { + t.Parallel() + + payloadCases := []struct { + // plaintext is the string that we'll be encrypting. + plaintext []byte + + // mutator allows a test case to modify the ciphertext before + // we attempt to decrypt it. + mutator func(*[]byte) + + // valid indicates if this test should pass or fail. + valid bool + }{ + // Proper payload, should decrypt. + { + plaintext: []byte("payload test plain text"), + mutator: nil, + valid: true, + }, + + // Mutator modifies cipher text, shouldn't decrypt. + { + plaintext: []byte("payload test plain text"), + mutator: func(p *[]byte) { + // Flip a byte in the payload to render it invalid. + (*p)[0] ^= 1 + }, + valid: false, + }, + + // Cipher text is too small, shouldn't decrypt. + { + plaintext: []byte("payload test plain text"), + mutator: func(p *[]byte) { + // Modify the cipher text to be zero length. + *p = []byte{} + }, + valid: false, + }, + } + + keyRing := &mockKeyRing{} + + for i, payloadCase := range payloadCases { + var cipherBuffer bytes.Buffer + + // First, we'll encrypt the passed payload with our scheme. + payloadReader := bytes.NewBuffer(payloadCase.plaintext) + err := encryptPayloadToWriter( + *payloadReader, &cipherBuffer, keyRing, + ) + if err != nil { + t.Fatalf("unable encrypt paylaod: %v", err) + } + + // If we have a mutator, then we'll wrong the mutator over the + // cipher text, then reset the main buffer and re-write the new + // cipher text. + if payloadCase.mutator != nil { + cipherText := cipherBuffer.Bytes() + + payloadCase.mutator(&cipherText) + + cipherBuffer.Reset() + cipherBuffer.Write(cipherText) + } + + plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing) + + switch { + // If this was meant to be a valid decryption, but we failed, + // then we'll return an error. + case err != nil && payloadCase.valid: + t.Fatalf("unable to decrypt valid payload case %v", i) + + // If this was meant to be an invalid decryption, and we didn't + // fail, then we'll return an error. + case err == nil && !payloadCase.valid: + t.Fatalf("payload was invalid yet was able to decrypt") + } + + // Only if this case was mean to be valid will we ensure the + // resulting decrypted plaintext matches the original input. + if payloadCase.valid && + !bytes.Equal(plaintext, payloadCase.plaintext) { + t.Fatalf("#%v: expected %v, got %v: ", i, + payloadCase.plaintext, plaintext) + } + } +} + +// TestInvalidKeyEncryption tests that encryption fails if we're unable to +// obtain a valid key. +func TestInvalidKeyEncryption(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + err := encryptPayloadToWriter(b, &b, &mockKeyRing{true}) + if err == nil { + t.Fatalf("expected error due to fail key gen") + } +} + +// TestInvalidKeyDecrytion tests that decryption fails if we're unable to +// obtain a valid key. +func TestInvalidKeyDecrytion(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + _, err := decryptPayloadFromReader(&b, &mockKeyRing{true}) + if err == nil { + t.Fatalf("expected error due to fail key gen") + } +}