157 lines
4.0 KiB
Go
157 lines
4.0 KiB
Go
|
package chanbackup
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/btcsuite/btcd/btcec"
|
||
|
"github.com/lightningnetwork/lnd/keychain"
|
||
|
)
|
||
|
|
||
|
// testWalletPrivKey is the fixed 32-byte private key that mockKeyRing
// parses in DeriveKey, so every key the mock hands out is deterministic.
var testWalletPrivKey = []byte{
	0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
	0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
	0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
	0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
}
|
||
|
|
||
|
// mockKeyRing is the key ring handed to the payload encryption helpers in
// these tests. Setting fail makes DeriveKey return an error instead of a key.
type mockKeyRing struct{ fail bool }
|
||
|
|
||
|
func (m *mockKeyRing) DeriveNextKey(keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
|
||
|
return keychain.KeyDescriptor{}, nil
|
||
|
}
|
||
|
func (m *mockKeyRing) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) {
|
||
|
if m.fail {
|
||
|
return keychain.KeyDescriptor{}, fmt.Errorf("fail")
|
||
|
}
|
||
|
|
||
|
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testWalletPrivKey)
|
||
|
return keychain.KeyDescriptor{
|
||
|
PubKey: pub,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// TestEncryptDecryptPayload tests that given a static key, we're able to
|
||
|
// properly decrypt and encrypted payload. We also test that we'll reject a
|
||
|
// ciphertext that has been modified.
|
||
|
func TestEncryptDecryptPayload(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
payloadCases := []struct {
|
||
|
// plaintext is the string that we'll be encrypting.
|
||
|
plaintext []byte
|
||
|
|
||
|
// mutator allows a test case to modify the ciphertext before
|
||
|
// we attempt to decrypt it.
|
||
|
mutator func(*[]byte)
|
||
|
|
||
|
// valid indicates if this test should pass or fail.
|
||
|
valid bool
|
||
|
}{
|
||
|
// Proper payload, should decrypt.
|
||
|
{
|
||
|
plaintext: []byte("payload test plain text"),
|
||
|
mutator: nil,
|
||
|
valid: true,
|
||
|
},
|
||
|
|
||
|
// Mutator modifies cipher text, shouldn't decrypt.
|
||
|
{
|
||
|
plaintext: []byte("payload test plain text"),
|
||
|
mutator: func(p *[]byte) {
|
||
|
// Flip a byte in the payload to render it invalid.
|
||
|
(*p)[0] ^= 1
|
||
|
},
|
||
|
valid: false,
|
||
|
},
|
||
|
|
||
|
// Cipher text is too small, shouldn't decrypt.
|
||
|
{
|
||
|
plaintext: []byte("payload test plain text"),
|
||
|
mutator: func(p *[]byte) {
|
||
|
// Modify the cipher text to be zero length.
|
||
|
*p = []byte{}
|
||
|
},
|
||
|
valid: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
keyRing := &mockKeyRing{}
|
||
|
|
||
|
for i, payloadCase := range payloadCases {
|
||
|
var cipherBuffer bytes.Buffer
|
||
|
|
||
|
// First, we'll encrypt the passed payload with our scheme.
|
||
|
payloadReader := bytes.NewBuffer(payloadCase.plaintext)
|
||
|
err := encryptPayloadToWriter(
|
||
|
*payloadReader, &cipherBuffer, keyRing,
|
||
|
)
|
||
|
if err != nil {
|
||
|
t.Fatalf("unable encrypt paylaod: %v", err)
|
||
|
}
|
||
|
|
||
|
// If we have a mutator, then we'll wrong the mutator over the
|
||
|
// cipher text, then reset the main buffer and re-write the new
|
||
|
// cipher text.
|
||
|
if payloadCase.mutator != nil {
|
||
|
cipherText := cipherBuffer.Bytes()
|
||
|
|
||
|
payloadCase.mutator(&cipherText)
|
||
|
|
||
|
cipherBuffer.Reset()
|
||
|
cipherBuffer.Write(cipherText)
|
||
|
}
|
||
|
|
||
|
plaintext, err := decryptPayloadFromReader(&cipherBuffer, keyRing)
|
||
|
|
||
|
switch {
|
||
|
// If this was meant to be a valid decryption, but we failed,
|
||
|
// then we'll return an error.
|
||
|
case err != nil && payloadCase.valid:
|
||
|
t.Fatalf("unable to decrypt valid payload case %v", i)
|
||
|
|
||
|
// If this was meant to be an invalid decryption, and we didn't
|
||
|
// fail, then we'll return an error.
|
||
|
case err == nil && !payloadCase.valid:
|
||
|
t.Fatalf("payload was invalid yet was able to decrypt")
|
||
|
}
|
||
|
|
||
|
// Only if this case was mean to be valid will we ensure the
|
||
|
// resulting decrypted plaintext matches the original input.
|
||
|
if payloadCase.valid &&
|
||
|
!bytes.Equal(plaintext, payloadCase.plaintext) {
|
||
|
t.Fatalf("#%v: expected %v, got %v: ", i,
|
||
|
payloadCase.plaintext, plaintext)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestInvalidKeyEncryption tests that encryption fails if we're unable to
|
||
|
// obtain a valid key.
|
||
|
func TestInvalidKeyEncryption(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
var b bytes.Buffer
|
||
|
err := encryptPayloadToWriter(b, &b, &mockKeyRing{true})
|
||
|
if err == nil {
|
||
|
t.Fatalf("expected error due to fail key gen")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// TestInvalidKeyDecrytion tests that decryption fails if we're unable to
|
||
|
// obtain a valid key.
|
||
|
func TestInvalidKeyDecrytion(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
var b bytes.Buffer
|
||
|
_, err := decryptPayloadFromReader(&b, &mockKeyRing{true})
|
||
|
if err == nil {
|
||
|
t.Fatalf("expected error due to fail key gen")
|
||
|
}
|
||
|
}
|