package keychain import ( "fmt" "io/ioutil" "math/rand" "os" "testing" "github.com/roasbeef/btcd/chaincfg" "github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcwallet/wallet" "github.com/roasbeef/btcwallet/walletdb" _ "github.com/roasbeef/btcwallet/walletdb/bdb" // Required in order to create the default database. ) // versionZeroKeyFamilies is a slice of all the known key families for first // version of the key derivation schema defined in this package. var versionZeroKeyFamilies = []KeyFamily{ KeyFamilyMultiSig, KeyFamilyRevocationBase, KeyFamilyHtlcBase, KeyFamilyPaymentBase, KeyFamilyDelayBase, KeyFamilyRevocationRoot, KeyFamilyNodeKey, } var ( testHDSeed = chainhash.Hash{ 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, 0x4f, 0x2f, 0x6f, 0x25, 0x98, 0xa3, 0xef, 0xb9, 0x69, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, } ) func createTestBtcWallet() (func(), *wallet.Wallet, error) { tempDir, err := ioutil.TempDir("", "keyring-lnwallet") if err != nil { return nil, nil, err } loader := wallet.NewLoader(&chaincfg.SimNetParams, tempDir) pass := []byte("test") baseWallet, err := loader.CreateNewWallet(pass, pass, testHDSeed[:]) if err != nil { return nil, nil, err } if err := baseWallet.Unlock(pass, nil); err != nil { return nil, nil, err } // We'll now ensure that the KeyScope: (201, 1) exists within the // internal waddrmgr. We'll need this in order to properly generate the // keys required for signing various contracts. _, err = baseWallet.Manager.FetchScopedKeyManager(lightningKeyScope) if err != nil { err := walletdb.Update(baseWallet.Database(), func(tx walletdb.ReadWriteTx) error { addrmgrNs := tx.ReadWriteBucket(waddrmgrNamespaceKey) _, err := baseWallet.Manager.NewScopedKeyManager( addrmgrNs, lightningKeyScope, lightningAddrSchema, ) return err }) if err != nil { return nil, nil, err } } cleanUp := func() { baseWallet.Lock() os.RemoveAll(tempDir) } return cleanUp, baseWallet, nil } // secretKeyRingConstructor is a function signature that's used as a generic // constructor for various implementations of the KeyRing interface. A string // naming the returned interface, a function closure that cleans up any // resources, and the clean up interface itself are to be returned. type keyRingConstructor func() (string, func(), KeyRing, error) // TestKeyRingDerivation tests that each known KeyRing implementation properly // adheres to the expected behavior of the set of interfaces. func TestKeyRingDerivation(t *testing.T) { t.Parallel() keyRingImplementations := []keyRingConstructor{ func() (string, func(), KeyRing, error) { cleanUp, wallet, err := createTestBtcWallet() if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet) return "btcwallet", cleanUp, keyRing, nil }, } // For each implementation constructor registered above, we'll execute // an identical set of tests in order to ensure that the interface // adheres to our nominal specification. for _, keyRingConstructor := range keyRingImplementations { keyRingName, cleanUp, keyRing, err := keyRingConstructor() if err != nil { t.Fatalf("unable to create key ring %v: %v", keyRingName, err) } defer cleanUp() success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) { // First, we'll ensure that we're able to derive keys // from each of the known key families. for _, keyFam := range versionZeroKeyFamilies { // First, we'll ensure that we can derive the // *next* key in the keychain. keyDesc, err := keyRing.DeriveNextKey(keyFam) if err != nil { t.Fatalf("unable to derive next for "+ "keyFam=%v: %v", keyFam, err) } // If we now try to manually derive the *first* // key, then we should get an identical public // key back. keyLoc := KeyLocator{ Family: keyFam, Index: 0, } firstKeyDesc, err := keyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive first key for "+ "keyFam=%v: %v", keyFam, err) } if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) { t.Fatalf("mismatched keys: expected %v, "+ "got %x", keyDesc.PubKey.SerializeCompressed(), firstKeyDesc.PubKey.SerializeCompressed()) } // If this succeeds, then we'll also try to // derive a random index within the range. randKeyIndex := uint32(rand.Int31()) keyLoc = KeyLocator{ Family: keyFam, Index: randKeyIndex, } _, err = keyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive key_index=%v "+ "for keyFam=%v: %v", randKeyIndex, keyFam, err) } } }) if !success { break } } } // secretKeyRingConstructor is a function signature that's used as a generic // constructor for various implementations of the SecretKeyRing interface. A // string naming the returned interface, a function closure that cleans up any // resources, and the clean up interface itself are to be returned. type secretKeyRingConstructor func() (string, func(), SecretKeyRing, error) // TestSecretKeyRingDerivation tests that each known SecretKeyRing // implementation properly adheres to the expected behavior of the set of // interface. func TestSecretKeyRingDerivation(t *testing.T) { t.Parallel() secretKeyRingImplementations := []secretKeyRingConstructor{ func() (string, func(), SecretKeyRing, error) { cleanUp, wallet, err := createTestBtcWallet() if err != nil { t.Fatalf("unable to create wallet: %v", err) } keyRing := NewBtcWalletKeyRing(wallet) return "btcwallet", cleanUp, keyRing, nil }, } // For each implementation constructor registered above, we'll execute // an identical set of tests in order to ensure that the interface // adheres to our nominal specification. for _, secretKeyRingConstructor := range secretKeyRingImplementations { keyRingName, cleanUp, secretKeyRing, err := secretKeyRingConstructor() if err != nil { t.Fatalf("unable to create secret key ring %v: %v", keyRingName, err) } defer cleanUp() success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) { // First, each key family, we'll ensure that we're able // to obtain the private key of a randomly select child // index within the key family. for _, keyFam := range versionZeroKeyFamilies { randKeyIndex := uint32(rand.Int31()) keyLoc := KeyLocator{ Family: keyFam, Index: randKeyIndex, } // First, we'll query for the public key for // this target key locator. pubKeyDesc, err := secretKeyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive pubkey "+ "(fam=%v, index=%v): %v", keyLoc.Family, keyLoc.Index, err) } // With the public key derive, ensure that // we're able to obtain the corresponding // private key correctly. privKey, err := secretKeyRing.DerivePrivKey(KeyDescriptor{ KeyLocator: keyLoc, }) if err != nil { t.Fatalf("unable to derive priv "+ "(fam=%v, index=%v): %v", keyLoc.Family, keyLoc.Index, err) } // Finally, ensure that the keys match up // properly. if !pubKeyDesc.PubKey.IsEqual(privKey.PubKey()) { t.Fatalf("pubkeys mismatched: expected %x, got %x", pubKeyDesc.PubKey.SerializeCompressed(), privKey.PubKey().SerializeCompressed()) } // TODO(roasbeef): scalar mult once integrated } }) if !success { break } } }