package keychain import ( "fmt" "io/ioutil" "math/rand" "os" "testing" "time" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/walletdb" "github.com/davecgh/go-spew/spew" _ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required in order to create the default database. ) // versionZeroKeyFamilies is a slice of all the known key families for first // version of the key derivation schema defined in this package. var versionZeroKeyFamilies = []KeyFamily{ KeyFamilyMultiSig, KeyFamilyRevocationBase, KeyFamilyHtlcBase, KeyFamilyPaymentBase, KeyFamilyDelayBase, KeyFamilyRevocationRoot, KeyFamilyNodeKey, } var ( testHDSeed = chainhash.Hash{ 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, 0x4f, 0x2f, 0x6f, 0x25, 0x98, 0xa3, 0xef, 0xb9, 0x69, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, } ) func createTestBtcWallet(coinType uint32) (func(), *wallet.Wallet, error) { tempDir, err := ioutil.TempDir("", "keyring-lnwallet") if err != nil { return nil, nil, err } loader := wallet.NewLoader(&chaincfg.SimNetParams, tempDir, true, 0) pass := []byte("test") baseWallet, err := loader.CreateNewWallet( pass, pass, testHDSeed[:], time.Time{}, ) if err != nil { return nil, nil, err } if err := baseWallet.Unlock(pass, nil); err != nil { return nil, nil, err } // Construct the key scope required to derive keys for the chose // coinType. chainKeyScope := waddrmgr.KeyScope{ Purpose: BIP0043Purpose, Coin: coinType, } // We'll now ensure that the KeyScope: (1017, coinType) exists within // the internal waddrmgr. We'll need this in order to properly generate // the keys required for signing various contracts. _, err = baseWallet.Manager.FetchScopedKeyManager(chainKeyScope) if err != nil { err := walletdb.Update(baseWallet.Database(), func(tx walletdb.ReadWriteTx) error { addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey) _, err := baseWallet.Manager.NewScopedKeyManager( addrmgrNs, chainKeyScope, lightningAddrSchema, ) return err }) if err != nil { return nil, nil, err } } cleanUp := func() { baseWallet.Lock() os.RemoveAll(tempDir) } return cleanUp, baseWallet, nil } func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) { t.Helper() if a != b { t.Fatalf("mismatched key locators: expected %v, "+ "got %v", spew.Sdump(a), spew.Sdump(b)) } } // secretKeyRingConstructor is a function signature that's used as a generic // constructor for various implementations of the KeyRing interface. A string // naming the returned interface, a function closure that cleans up any // resources, and the clean up interface itself are to be returned. type keyRingConstructor func() (string, func(), KeyRing, error) // TestKeyRingDerivation tests that each known KeyRing implementation properly // adheres to the expected behavior of the set of interfaces. func TestKeyRingDerivation(t *testing.T) { t.Parallel() keyRingImplementations := []keyRingConstructor{ func() (string, func(), KeyRing, error) { cleanUp, wallet, err := createTestBtcWallet( CoinTypeBitcoin, ) if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet, CoinTypeBitcoin) return "btcwallet", cleanUp, keyRing, nil }, func() (string, func(), KeyRing, error) { cleanUp, wallet, err := createTestBtcWallet( CoinTypeLitecoin, ) if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet, CoinTypeLitecoin) return "ltcwallet", cleanUp, keyRing, nil }, func() (string, func(), KeyRing, error) { cleanUp, wallet, err := createTestBtcWallet( CoinTypeTestnet, ) if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet, CoinTypeTestnet) return "testwallet", cleanUp, keyRing, nil }, } const numKeysToDerive = 10 // For each implementation constructor registered above, we'll execute // an identical set of tests in order to ensure that the interface // adheres to our nominal specification. for _, keyRingConstructor := range keyRingImplementations { keyRingName, cleanUp, keyRing, err := keyRingConstructor() if err != nil { t.Fatalf("unable to create key ring %v: %v", keyRingName, err) } defer cleanUp() success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) { // First, we'll ensure that we're able to derive keys // from each of the known key families. for _, keyFam := range versionZeroKeyFamilies { // First, we'll ensure that we can derive the // *next* key in the keychain. keyDesc, err := keyRing.DeriveNextKey(keyFam) if err != nil { t.Fatalf("unable to derive next for "+ "keyFam=%v: %v", keyFam, err) } assertEqualKeyLocator(t, KeyLocator{ Family: keyFam, Index: 0, }, keyDesc.KeyLocator, ) // We'll now re-derive that key to ensure that // we're able to properly access the key via // the random access derivation methods. keyLoc := KeyLocator{ Family: keyFam, Index: 0, } firstKeyDesc, err := keyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive first key for "+ "keyFam=%v: %v", keyFam, err) } if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) { t.Fatalf("mismatched keys: expected %x, "+ "got %x", keyDesc.PubKey.SerializeCompressed(), firstKeyDesc.PubKey.SerializeCompressed()) } assertEqualKeyLocator(t, KeyLocator{ Family: keyFam, Index: 0, }, firstKeyDesc.KeyLocator, ) // If we now try to manually derive the next 10 // keys (including the original key), then we // should get an identical public key back and // their KeyLocator information // should be set properly. for i := 0; i < numKeysToDerive+1; i++ { keyLoc := KeyLocator{ Family: keyFam, Index: uint32(i), } keyDesc, err := keyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive first key for "+ "keyFam=%v: %v", keyFam, err) } // Ensure that the key locator matches // up as well. assertEqualKeyLocator( t, keyLoc, keyDesc.KeyLocator, ) } // If this succeeds, then we'll also try to // derive a random index within the range. randKeyIndex := uint32(rand.Int31()) keyLoc = KeyLocator{ Family: keyFam, Index: randKeyIndex, } keyDesc, err = keyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive key_index=%v "+ "for keyFam=%v: %v", randKeyIndex, keyFam, err) } assertEqualKeyLocator( t, keyLoc, keyDesc.KeyLocator, ) } }) if !success { break } } } // secretKeyRingConstructor is a function signature that's used as a generic // constructor for various implementations of the SecretKeyRing interface. A // string naming the returned interface, a function closure that cleans up any // resources, and the clean up interface itself are to be returned. type secretKeyRingConstructor func() (string, func(), SecretKeyRing, error) // TestSecretKeyRingDerivation tests that each known SecretKeyRing // implementation properly adheres to the expected behavior of the set of // interface. func TestSecretKeyRingDerivation(t *testing.T) { t.Parallel() secretKeyRingImplementations := []secretKeyRingConstructor{ func() (string, func(), SecretKeyRing, error) { cleanUp, wallet, err := createTestBtcWallet( CoinTypeBitcoin, ) if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet, CoinTypeBitcoin) return "btcwallet", cleanUp, keyRing, nil }, func() (string, func(), SecretKeyRing, error) { cleanUp, wallet, err := createTestBtcWallet( CoinTypeLitecoin, ) if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet, CoinTypeLitecoin) return "ltcwallet", cleanUp, keyRing, nil }, func() (string, func(), SecretKeyRing, error) { cleanUp, wallet, err := createTestBtcWallet( CoinTypeTestnet, ) if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet, CoinTypeTestnet) return "testwallet", cleanUp, keyRing, nil }, } // For each implementation constructor registered above, we'll execute // an identical set of tests in order to ensure that the interface // adheres to our nominal specification. for _, secretKeyRingConstructor := range secretKeyRingImplementations { keyRingName, cleanUp, secretKeyRing, err := secretKeyRingConstructor() if err != nil { t.Fatalf("unable to create secret key ring %v: %v", keyRingName, err) } defer cleanUp() success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) { // For, each key family, we'll ensure that we're able // to obtain the private key of a randomly select child // index within the key family. for _, keyFam := range versionZeroKeyFamilies { randKeyIndex := uint32(rand.Int31()) keyLoc := KeyLocator{ Family: keyFam, Index: randKeyIndex, } // First, we'll query for the public key for // this target key locator. pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive pubkey "+ "(fam=%v, index=%v): %v", keyLoc.Family, keyLoc.Index, err) } // With the public key derive, ensure that // we're able to obtain the corresponding // private key correctly. privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{ KeyLocator: keyLoc, }) if err != nil { t.Fatalf("unable to derive priv "+ "(fam=%v, index=%v): %v", keyLoc.Family, keyLoc.Index, err) } // Finally, ensure that the keys match up // properly. if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) { t.Fatalf("pubkeys mismatched: expected %x, got %x", pubKeyDesc.PubKey.SerializeCompressed(), privKey.PubKey().SerializeCompressed()) } // Next, we'll test that we're able to derive a // key given only the public key and key // family. // // Derive a new key from the key ring. keyDesc, err := secretKeyRing.DeriveNextKey(keyFam) if err != nil { t.Fatalf("unable to derive key: %v", err) } // We'll now construct a key descriptor that // requires us to scan the key range, and query // for the key, we should be able to find it as // it's valid. keyDesc = KeyDescriptor{ PubKey: keyDesc.PubKey, KeyLocator: KeyLocator{ Family: keyFam, }, } privKey, err = secretKeyRing.DerivePrivKey(keyDesc) if err != nil { t.Fatalf("unable to derive priv key "+ "via scanning: %v", err) } // Having to resort to scanning, we should be // able to find the target public key. if !keyDesc.PubKey.IsEqual(privKey.PubKey()) { t.Fatalf("pubkeys mismatched: expected %x, got %x", pubKeyDesc.PubKey.SerializeCompressed(), privKey.PubKey().SerializeCompressed()) } // We'll try again, but this time with an // unknown public key. _, pub := btcec.PrivKeyFromBytes( btcec.S256(), testHDSeed[:], ) keyDesc.PubKey = pub // If we attempt to query for this key, then we // should get ErrCannotDerivePrivKey. privKey, err = secretKeyRing.DerivePrivKey( keyDesc, ) if err != ErrCannotDerivePrivKey { t.Fatalf("expected %T, instead got %v", ErrCannotDerivePrivKey, err) } // TODO(roasbeef): scalar mult once integrated } }) if !success { break } } } func init() { // We'll clamp the max range scan to constrain the run time of the // private key scan test. MaxKeyRangeScan = 3 }