keychain: add initial set up interface-level tests
This commit is contained in:
parent
d6f54b30fc
commit
cd068d8888
255
keychain/interface_test.go
Normal file
255
keychain/interface_test.go
Normal file
@ -0,0 +1,255 @@
|
||||
package keychain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/roasbeef/btcd/chaincfg"
|
||||
"github.com/roasbeef/btcd/chaincfg/chainhash"
|
||||
"github.com/roasbeef/btcwallet/wallet"
|
||||
"github.com/roasbeef/btcwallet/walletdb"
|
||||
|
||||
_ "github.com/roasbeef/btcwallet/walletdb/bdb" // Required in order to create the default database.
|
||||
)
|
||||
|
||||
// versionZeroKeyFamilies is a slice of all the known key families for first
|
||||
// version of the key derivation schema defined in this package.
|
||||
var versionZeroKeyFamilies = []KeyFamily{
|
||||
KeyFamilyMultiSig,
|
||||
KeyFamilyRevocationBase,
|
||||
KeyFamilyHtlcBase,
|
||||
KeyFamilyPaymentBase,
|
||||
KeyFamilyDelayBase,
|
||||
KeyFamilyRevocationRoot,
|
||||
KeyFamilyNodeKey,
|
||||
}
|
||||
|
||||
var (
|
||||
testHDSeed = chainhash.Hash{
|
||||
0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab,
|
||||
0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4,
|
||||
0x4f, 0x2f, 0x6f, 0x25, 0x98, 0xa3, 0xef, 0xb9,
|
||||
0x69, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53,
|
||||
}
|
||||
)
|
||||
|
||||
func createTestBtcWallet() (func(), *wallet.Wallet, error) {
|
||||
tempDir, err := ioutil.TempDir("", "keyring-lnwallet")
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
loader := wallet.NewLoader(&chaincfg.SimNetParams, tempDir)
|
||||
|
||||
pass := []byte("test")
|
||||
|
||||
baseWallet, err := loader.CreateNewWallet(pass, pass, testHDSeed[:])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := baseWallet.Unlock(pass, nil); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// We'll now ensure that the KeyScope: (201, 1) exists within the
|
||||
// internal waddrmgr. We'll need this in order to properly generate the
|
||||
// keys required for signing various contracts.
|
||||
_, err = baseWallet.Manager.FetchScopedKeyManager(lightningKeyScope)
|
||||
if err != nil {
|
||||
err := walletdb.Update(baseWallet.Database(), func(tx walletdb.ReadWriteTx) error {
|
||||
addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey)
|
||||
|
||||
_, err := baseWallet.Manager.NewScopedKeyManager(
|
||||
addrmgrNs, lightningKeyScope, lightningAddrSchema,
|
||||
)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cleanUp := func() {
|
||||
baseWallet.Lock()
|
||||
os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
return cleanUp, baseWallet, nil
|
||||
}
|
||||
|
||||
// secretKeyRingConstructor is a function signature that's used as a generic
|
||||
// constructor for various implementations of the KeyRing interface. A string
|
||||
// naming the returned interface, a function closure that cleans up any
|
||||
// resources, and the clean up interface itself are to be returned.
|
||||
type keyRingConstructor func() (string, func(), KeyRing, error)
|
||||
|
||||
// TestKeyRingDerivation tests that each known KeyRing implementation properly
|
||||
// adheres to the expected behavior of the set of interfaces.
|
||||
func TestKeyRingDerivation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
keyRingImplementations := []keyRingConstructor{
|
||||
func() (string, func(), KeyRing, error) {
|
||||
cleanUp, wallet, err := createTestBtcWallet()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create wallet: %v", err)
|
||||
}
|
||||
|
||||
keyRing := NewBtcWalletKeyRing(wallet)
|
||||
|
||||
return "btcwallet", cleanUp, keyRing, nil
|
||||
},
|
||||
}
|
||||
|
||||
// For each implementation constructor registered above, we'll execute
|
||||
// an identical set of tests in order to ensure that the interface
|
||||
// adheres to our nominal specification.
|
||||
for _, keyRingConstructor := range keyRingImplementations {
|
||||
keyRingName, cleanUp, keyRing, err := keyRingConstructor()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create key ring %v: %v", keyRingName,
|
||||
err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) {
|
||||
// First, we'll ensure that we're able to derive keys
|
||||
// from each of the known key families.
|
||||
for _, keyFam := range versionZeroKeyFamilies {
|
||||
// First, we'll ensure that we can derive the
|
||||
// *next* key in the keychain.
|
||||
keyDesc, err := keyRing.DeriveNextKey(keyFam)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to derive next for "+
|
||||
"keyFam=%v: %v", keyFam, err)
|
||||
}
|
||||
|
||||
// If we now try to manually derive the *first*
|
||||
// key, then we should get an identical public
|
||||
// key back.
|
||||
keyLoc := KeyLocator{
|
||||
Family: keyFam,
|
||||
Index: 0,
|
||||
}
|
||||
firstKeyDesc, err := keyRing.DeriveKey(keyLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to derive first key for "+
|
||||
"keyFam=%v: %v", keyFam, err)
|
||||
}
|
||||
|
||||
if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) {
|
||||
t.Fatalf("mismatched keys: expected %v, "+
|
||||
"got %x",
|
||||
keyDesc.PubKey.SerializeCompressed(),
|
||||
firstKeyDesc.PubKey.SerializeCompressed())
|
||||
}
|
||||
|
||||
// If this succeeds, then we'll also try to
|
||||
// derive a random index within the range.
|
||||
randKeyIndex := uint32(rand.Int31())
|
||||
keyLoc = KeyLocator{
|
||||
Family: keyFam,
|
||||
Index: randKeyIndex,
|
||||
}
|
||||
_, err = keyRing.DeriveKey(keyLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to derive key_index=%v "+
|
||||
"for keyFam=%v: %v",
|
||||
randKeyIndex, keyFam, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
if !success {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// secretKeyRingConstructor is a function signature that's used as a generic
|
||||
// constructor for various implementations of the SecretKeyRing interface. A
|
||||
// string naming the returned interface, a function closure that cleans up any
|
||||
// resources, and the clean up interface itself are to be returned.
|
||||
type secretKeyRingConstructor func() (string, func(), SecretKeyRing, error)
|
||||
|
||||
// TestSecretKeyRingDerivation tests that each known SecretKeyRing
|
||||
// implementation properly adheres to the expected behavior of the set of
|
||||
// interface.
|
||||
func TestSecretKeyRingDerivation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
secretKeyRingImplementations := []secretKeyRingConstructor{
|
||||
func() (string, func(), SecretKeyRing, error) {
|
||||
cleanUp, wallet, err := createTestBtcWallet()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create wallet: %v", err)
|
||||
}
|
||||
|
||||
keyRing := NewBtcWalletKeyRing(wallet)
|
||||
|
||||
return "btcwallet", cleanUp, keyRing, nil
|
||||
},
|
||||
}
|
||||
|
||||
// For each implementation constructor registered above, we'll execute
|
||||
// an identical set of tests in order to ensure that the interface
|
||||
// adheres to our nominal specification.
|
||||
for _, secretKeyRingConstructor := range secretKeyRingImplementations {
|
||||
keyRingName, cleanUp, secretKeyRing, err := secretKeyRingConstructor()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create secret key ring %v: %v",
|
||||
keyRingName, err)
|
||||
}
|
||||
defer cleanUp()
|
||||
|
||||
success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) {
|
||||
// First, each key family, we'll ensure that we're able
|
||||
// to obtain the private key of a randomly select child
|
||||
// index within the key family.
|
||||
for _, keyFam := range versionZeroKeyFamilies {
|
||||
randKeyIndex := uint32(rand.Int31())
|
||||
keyLoc := KeyLocator{
|
||||
Family: keyFam,
|
||||
Index: randKeyIndex,
|
||||
}
|
||||
|
||||
// First, we'll query for the public key for
|
||||
// this target key locator.
|
||||
pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to derive pubkey "+
|
||||
"(fam=%v, index=%v): %v",
|
||||
keyLoc.Family,
|
||||
keyLoc.Index, err)
|
||||
}
|
||||
|
||||
// With the public key derive, ensure that
|
||||
// we're able to obtain the corresponding
|
||||
// private key correctly.
|
||||
privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{
|
||||
KeyLocator: keyLoc,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to derive priv "+
|
||||
"(fam=%v, index=%v): %v", keyLoc.Family,
|
||||
keyLoc.Index, err)
|
||||
}
|
||||
|
||||
// Finally, ensure that the keys match up
|
||||
// properly.
|
||||
if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) {
|
||||
t.Fatalf("pubkeys mismatched: expected %x, got %x",
|
||||
pubKeyDesc.PubKey.SerializeCompressed(),
|
||||
privKey.PubKey().SerializeCompressed())
|
||||
}
|
||||
|
||||
// TODO(roasbeef): scalar mult once integrated
|
||||
}
|
||||
})
|
||||
if !success {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user