diff --git a/keychain/btcwallet.go b/keychain/btcwallet.go index e1feecfc..b8a47b12 100644 --- a/keychain/btcwallet.go +++ b/keychain/btcwallet.go @@ -269,24 +269,74 @@ func (b *BtcWalletKeyRing) DerivePrivKey(keyDesc KeyDescriptor) (*btcec.PrivateK return err } - // Now that we know the account exists, we can safely derive - // the full private key from the given path. - path := waddrmgr.DerivationPath{ + // If the public key isn't set or they have a non-zero index, + // then we know that the caller instead knows the derivation + // path for a key. + if keyDesc.PubKey == nil || keyDesc.Index > 0 { + // Now that we know the account exists, we can safely + // derive the full private key from the given path. + path := waddrmgr.DerivationPath{ + Account: uint32(keyDesc.Family), + Branch: 0, + Index: uint32(keyDesc.Index), + } + addr, err := scope.DeriveFromKeyPath(addrmgrNs, path) + if err != nil { + return err + } + + key, err = addr.(waddrmgr.ManagedPubKeyAddress).PrivKey() + if err != nil { + return err + } + + return nil + } + + // If the public key isn't nil, then this indicates that we + // need to scan for the private key, assuming that we know the + // valid key family. + nextPath := waddrmgr.DerivationPath{ Account: uint32(keyDesc.Family), Branch: 0, - Index: uint32(keyDesc.Index), - } - addr, err := scope.DeriveFromKeyPath(addrmgrNs, path) - if err != nil { - return err + Index: 0, } - key, err = addr.(waddrmgr.ManagedPubKeyAddress).PrivKey() - if err != nil { - return err + // We'll now iterate through our key range in an attempt to + // find the target public key. + // + // TODO(roasbeef): possibly move scanning into wallet to allow + // to be parallelized + for i := 0; i < MaxKeyRangeScan; i++ { + // Derive the next key in the range and fetch its + // managed address. + addr, err := scope.DeriveFromKeyPath( + addrmgrNs, nextPath, + ) + if err != nil { + return err + } + managedAddr := addr.(waddrmgr.ManagedPubKeyAddress) + + // If this is the target public key, then we'll return + // it directly back to the caller. + if managedAddr.PubKey().IsEqual(keyDesc.PubKey) { + key, err = managedAddr.PrivKey() + if err != nil { + return err + } + + return nil + } + + // This wasn't the target key, so roll forward and try + // the next one. + nextPath.Index++ } - return nil + // If we reach this point, then we we're unable to derive the + // private key, so return an error back to the user. + return ErrCannotDerivePrivKey }) if err != nil { return nil, err diff --git a/keychain/derivation.go b/keychain/derivation.go index c0ad4355..f63587d7 100644 --- a/keychain/derivation.go +++ b/keychain/derivation.go @@ -1,6 +1,10 @@ package keychain -import "github.com/btcsuite/btcd/btcec" +import ( + "fmt" + + "github.com/btcsuite/btcd/btcec" +) const ( // KeyDerivationVersion is the version of the key derivation schema @@ -20,6 +24,18 @@ const ( BIP0043Purpose = 1017 ) +var ( + // MaxKeyRangeScan is the maximum number of keys that we'll attempt to + // scan with if a caller knows the public key, but not the KeyLocator + // and wishes to derive a private key. + MaxKeyRangeScan = 100000 + + // ErrCannotDerivePrivKey is returned when DerivePrivKey is unable to + // derive a private key given only the public key and target key + // family. + ErrCannotDerivePrivKey = fmt.Errorf("unable to derive private key") +) + // KeyFamily represents a "family" of keys that will be used within various // contracts created by lnd. These families are meant to be distinct branches // within the HD key chain of the backing wallet. Usage of key families within @@ -141,7 +157,10 @@ type SecretKeyRing interface { KeyRing // DerivePrivKey attempts to derive the private key that corresponds to - // the passed key descriptor. + // the passed key descriptor. If the public key is set, then this + // method will perform an in-order scan over the key set, with a max of + // MaxKeyRangeScan keys. In order for this to work, the caller MUST set + // the KeyFamily within the partially populated KeyLocator. DerivePrivKey(keyDesc KeyDescriptor) (*btcec.PrivateKey, error) // ScalarMult performs a scalar multiplication (ECDH-like operation) diff --git a/keychain/interface_test.go b/keychain/interface_test.go index a297ef96..82e70e7c 100644 --- a/keychain/interface_test.go +++ b/keychain/interface_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcwallet/waddrmgr" @@ -316,7 +317,7 @@ func TestSecretKeyRingDerivation(t *testing.T) { defer cleanUp() success := t.Run(fmt.Sprintf("%v", keyRingName), func(t *testing.T) { - // First, each key family, we'll ensure that we're able + // For, each key family, we'll ensure that we're able // to obtain the private key of a randomly select child // index within the key family. for _, keyFam := range versionZeroKeyFamilies { @@ -356,6 +357,57 @@ func TestSecretKeyRingDerivation(t *testing.T) { privKey.PubKey().SerializeCompressed()) } + // Next, we'll test that we're able to derive a + // key given only the public key and key + // family. + // + // Derive a new key from the key ring. + keyDesc, err := secretKeyRing.DeriveNextKey(keyFam) + if err != nil { + t.Fatalf("unable to derive key: %v", err) + } + + // We'll now construct a key descriptor that + // requires us to scan the key range, and query + // for the key, we should be able to find it as + // it's valid. + keyDesc = KeyDescriptor{ + PubKey: keyDesc.PubKey, + KeyLocator: KeyLocator{ + Family: keyFam, + }, + } + privKey, err = secretKeyRing.DerivePrivKey(keyDesc) + if err != nil { + t.Fatalf("unable to derive priv key "+ + "via scanning: %v", err) + } + + // Having to resort to scanning, we should be + // able to find the target public key. + if !keyDesc.PubKey.IsEqual(privKey.PubKey()) { + t.Fatalf("pubkeys mismatched: expected %x, got %x", + pubKeyDesc.PubKey.SerializeCompressed(), + privKey.PubKey().SerializeCompressed()) + } + + // We'll try again, but this time with an + // unknown public key. + _, pub := btcec.PrivKeyFromBytes( + btcec.S256(), testHDSeed[:], + ) + keyDesc.PubKey = pub + + // If we attempt to query for this key, then we + // should get ErrCannotDerivePrivKey. + privKey, err = secretKeyRing.DerivePrivKey( + keyDesc, + ) + if err != ErrCannotDerivePrivKey { + t.Fatalf("expected %T, instead got %v", + ErrCannotDerivePrivKey, err) + } + // TODO(roasbeef): scalar mult once integrated } }) @@ -364,3 +416,9 @@ func TestSecretKeyRingDerivation(t *testing.T) { } } } + +func init() { + // We'll clamp the max range scan to constrain the run time of the + // private key scan test. + MaxKeyRangeScan = 3 +}