keychain: extend TestKeyRingDerivation to check KeyLocators of derived keys

This commit is contained in:
Olaoluwa Osuntokun 2018-08-13 19:20:57 -07:00
parent ad25ae1a07
commit cf06b041a4
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21

View File

@ -5,6 +5,7 @@ import (
"io/ioutil"
"math/rand"
"os"
"runtime"
"testing"
"time"
@ -13,6 +14,7 @@ import (
"github.com/btcsuite/btcwallet/waddrmgr"
"github.com/btcsuite/btcwallet/wallet"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/davecgh/go-spew/spew"
_ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required in order to create the default database.
)
@ -91,6 +93,14 @@ func createTestBtcWallet(coinType uint32) (func(), *wallet.Wallet, error) {
return cleanUp, baseWallet, nil
}
func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) {
_, _, line, _ := runtime.Caller(1)
if a != b {
t.Fatalf("line #%v: mismatched key locators: expected %v, "+
"got %v", line, spew.Sdump(a), spew.Sdump(b))
}
}
// secretKeyRingConstructor is a function signature that's used as a generic
// constructor for various implementations of the KeyRing interface. A string
// naming the returned interface, a function closure that cleans up any
@ -141,6 +151,8 @@ func TestKeyRingDerivation(t *testing.T) {
},
}
const numKeysToDerive = 10
// For each implementation constructor registered above, we'll execute
// an identical set of tests in order to ensure that the interface
// adheres to our nominal specification.
@ -163,10 +175,16 @@ func TestKeyRingDerivation(t *testing.T) {
t.Fatalf("unable to derive next for "+
"keyFam=%v: %v", keyFam, err)
}
assertEqualKeyLocator(t,
KeyLocator{
Family: keyFam,
Index: 0,
}, keyDesc.KeyLocator,
)
// If we now try to manually derive the *first*
// key, then we should get an identical public
// key back.
// We'll now re-derive that key to ensure that
// we're able to properly access the key via
// the random access derivation methods.
keyLoc := KeyLocator{
Family: keyFam,
Index: 0,
@ -176,13 +194,41 @@ func TestKeyRingDerivation(t *testing.T) {
t.Fatalf("unable to derive first key for "+
"keyFam=%v: %v", keyFam, err)
}
if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) {
t.Fatalf("mismatched keys: expected %v, "+
t.Fatalf("mismatched keys: expected %x, "+
"got %x",
keyDesc.PubKey.SerializeCompressed(),
firstKeyDesc.PubKey.SerializeCompressed())
}
assertEqualKeyLocator(t,
KeyLocator{
Family: keyFam,
Index: 0,
}, firstKeyDesc.KeyLocator,
)
// If we now try to manually derive the next 10
// keys (including the original key), then we
// should get an identical public key back and
// their KeyLocator information
// should be set properly.
for i := 0; i < numKeysToDerive+1; i++ {
keyLoc := KeyLocator{
Family: keyFam,
Index: uint32(i),
}
keyDesc, err := keyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive first key for "+
"keyFam=%v: %v", keyFam, err)
}
// Ensure that the key locator matches
// up as well.
assertEqualKeyLocator(
t, keyLoc, keyDesc.KeyLocator,
)
}
// If this succeeds, then we'll also try to
// derive a random index within the range.
@ -191,12 +237,15 @@ func TestKeyRingDerivation(t *testing.T) {
Family: keyFam,
Index: randKeyIndex,
}
_, err = keyRing.DeriveKey(keyLoc)
keyDesc, err = keyRing.DeriveKey(keyLoc)
if err != nil {
t.Fatalf("unable to derive key_index=%v "+
"for keyFam=%v: %v",
randKeyIndex, keyFam, err)
}
assertEqualKeyLocator(
t, keyLoc, keyDesc.KeyLocator,
)
}
})
if !success {