From e86babe1339adf8f920b365fb3e8c94277a08963 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Sun, 9 Dec 2018 18:30:59 -0800 Subject: [PATCH] keychain: extend DerivePrivKey to derive based on pubkey+KeyFamily In this commit, we extend the DerivePrivKey method to allow callers that don't know the full KeyLocator information to attempt to derive a private key via a brute force mechanism. If we don't now the full KeyLoactor, then given the KeyFamily, we can walk down the derivation path and compare keys one by one. In order to ensure we don' t enter an infinite loop when given an unknown public key, we cap the number of keys derived at 100k. An upcoming feature to lnd that adds static channel backups will utilize this feature, as we need to derive the shachain root given only the public key and key family, as we don't currently store this KeyLocator on disk. --- keychain/btcwallet.go | 74 +++++++++++++++++++++++++++++++------- keychain/derivation.go | 23 ++++++++++-- keychain/interface_test.go | 60 ++++++++++++++++++++++++++++++- 3 files changed, 142 insertions(+), 15 deletions(-) diff --git a/keychain/btcwallet.go b/keychain/btcwallet.go index e1feecfc..b8a47b12 100644 --- a/keychain/btcwallet.go +++ b/keychain/btcwallet.go @@ -269,24 +269,74 @@ func (b *BtcWalletKeyRing) DerivePrivKey(keyDesc KeyDescriptor) (*btcec.PrivateK return err } - // Now that we know the account exists, we can safely derive - // the full private key from the given path. - path := waddrmgr.DerivationPath{ + // If the public key isn't set or they have a non-zero index, + // then we know that the caller instead knows the derivation + // path for a key. + if keyDesc.PubKey == nil || keyDesc.Index > 0 { + // Now that we know the account exists, we can safely + // derive the full private key from the given path. + path := waddrmgr.DerivationPath{ + Account: uint32(keyDesc.Family), + Branch: 0, + Index: uint32(keyDesc.Index), + } + addr, err := scope.DeriveFromKeyPath(addrmgrNs, path) + if err != nil { + return err + } + + key, err = addr.(waddrmgr.ManagedPubKeyAddress).PrivKey() + if err != nil { + return err + } + + return nil + } + + // If the public key isn't nil, then this indicates that we + // need to scan for the private key, assuming that we know the + // valid key family. + nextPath := waddrmgr.DerivationPath{ Account: uint32(keyDesc.Family), Branch: 0, - Index: uint32(keyDesc.Index), - } - addr, err := scope.DeriveFromKeyPath(addrmgrNs, path) - if err != nil { - return err + Index: 0, } - key, err = addr.(waddrmgr.ManagedPubKeyAddress).PrivKey() - if err != nil { - return err + // We'll now iterate through our key range in an attempt to + // find the target public key. + // + // TODO(roasbeef): possibly move scanning into wallet to allow + // to be parallelized + for i := 0; i < MaxKeyRangeScan; i++ { + // Derive the next key in the range and fetch its + // managed address. + addr, err := scope.DeriveFromKeyPath( + addrmgrNs, nextPath, + ) + if err != nil { + return err + } + managedAddr := addr.(waddrmgr.ManagedPubKeyAddress) + + // If this is the target public key, then we'll return + // it directly back to the caller. + if managedAddr.PubKey().IsEqual(keyDesc.PubKey) { + key, err = managedAddr.PrivKey() + if err != nil { + return err + } + + return nil + } + + // This wasn't the target key, so roll forward and try + // the next one. + nextPath.Index++ } - return nil + // If we reach this point, then we we're unable to derive the + // private key, so return an error back to the user. + return ErrCannotDerivePrivKey }) if err != nil { return nil, err diff --git a/keychain/derivation.go b/keychain/derivation.go index c0ad4355..f63587d7 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -1,6 +1,10 @@ package keychain -import "github.com/btcsuite/btcd/btcec" +import ( + "fmt" + + "github.com/btcsuite/btcd/btcec" +) const ( // KeyDerivationVersion is the version of the key derivation schema @@ -20,6 +24,18 @@ const ( BIP0043Purpose = 1017 ) +var ( + // MaxKeyRangeScan is the maximum number of keys that we'll attempt to + // scan with if a caller knows the public key, but not the KeyLocator + // and wishes to derive a private key. + MaxKeyRangeScan = 100000 + + // ErrCannotDerivePrivKey is returned when DerivePrivKey is unable to + // derive a private key given only the public key and target key + // family. + ErrCannotDerivePrivKey = fmt.Errorf("unable to derive private key") +) + // KeyFamily represents a "family" of keys that will be used within various // contracts created by lnd. These families are meant to be distinct branches // within the HD key chain of the backing wallet. Usage of key families within @@ -141,7 +157,10 @@ type SecretKeyRing interface { KeyRing // DerivePrivKey attempts to derive the private key that corresponds to - // the passed key descriptor. + // the passed key descriptor. If the public key is set, then this + // method will perform an in-order scan over the key set, with a max of + // MaxKeyRangeScan keys. In order for this to work, the caller MUST set + // the KeyFamily within the partially populated KeyLocator. DerivePrivKey(keyDesc KeyDescriptor) (*btcec.PrivateKey, error) // ScalarMult performs a scalar multiplication (ECDH-like operation) diff --git a/keychain/interface_test.go b/keychain/interface_test.go index a297ef96..82e70e7c 100644 --- a/keychain/interface_test.go +++ b/keychain/interface_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcwallet/waddrmgr" @@ -316,7 +317,7 @@ func TestSecretKeyRingDerivation(t *testing.T) { defer cleanUp() success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) { - // First, each key family, we'll ensure that we're able + // For, each key family, we'll ensure that we're able // to obtain the private key of a randomly select child // index within the key family. for _, keyFam := range versionZeroKeyFamilies { @@ -356,6 +357,57 @@ func TestSecretKeyRingDerivation(t *testing.T) { privKey.PubKey().SerializeCompressed()) } + // Next, we'll test that we're able to derive a + // key given only the public key and key + // family. + // + // Derive a new key from the key ring. + keyDesc, err := secretKeyRing.DeriveNextKey(keyFam) + if err != nil { + t.Fatalf("unable to derive key: %v", err) + } + + // We'll now construct a key descriptor that + // requires us to scan the key range, and query + // for the key, we should be able to find it as + // it's valid. + keyDesc = KeyDescriptor{ + PubKey: keyDesc.PubKey, + KeyLocator: KeyLocator{ + Family: keyFam, + }, + } + privKey, err = secretKeyRing.DerivePrivKey(keyDesc) + if err != nil { + t.Fatalf("unable to derive priv key "+ + "via scanning: %v", err) + } + + // Having to resort to scanning, we should be + // able to find the target public key. + if !keyDesc.PubKey.IsEqual(privKey.PubKey()) { + t.Fatalf("pubkeys mismatched: expected %x, got %x", + pubKeyDesc.PubKey.SerializeCompressed(), + privKey.PubKey().SerializeCompressed()) + } + + // We'll try again, but this time with an + // unknown public key. + _, pub := btcec.PrivKeyFromBytes( + btcec.S256(), testHDSeed[:], + ) + keyDesc.PubKey = pub + + // If we attempt to query for this key, then we + // should get ErrCannotDerivePrivKey. + privKey, err = secretKeyRing.DerivePrivKey( + keyDesc, + ) + if err != ErrCannotDerivePrivKey { + t.Fatalf("expected %T, instead got %v", + ErrCannotDerivePrivKey, err) + } + // TODO(roasbeef): scalar mult once integrated } }) @@ -364,3 +416,9 @@ func TestSecretKeyRingDerivation(t *testing.T) { } } } + +func init() { + // We'll clamp the max range scan to constrain the run time of the + // private key scan test. + MaxKeyRangeScan = 3 +}