diff --git a/keychain/interface_test.go b/keychain/interface_test.go index ceae8dd3..ef513cab 100644 --- a/keychain/interface_test.go +++ b/keychain/interface_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "math/rand" "os" + "runtime" "testing" "time" @@ -13,6 +14,7 @@ import ( "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/walletdb" + "github.com/davecgh/go-spew/spew" _ "github.com/btcsuite/btcwallet/walletdb/bdb" // Required in order to create the default database. ) @@ -91,6 +93,14 @@ func createTestBtcWallet(coinType uint32) (func(), *wallet.Wallet, error) { return cleanUp, baseWallet, nil } +func assertEqualKeyLocator(t *testing.T, a, b KeyLocator) { + _, _, line, _ := runtime.Caller(1) + if a != b { + t.Fatalf("line #%v: mismatched key locators: expected %v, "+ + "got %v", line, spew.Sdump(a), spew.Sdump(b)) + } +} + // secretKeyRingConstructor is a function signature that's used as a generic // constructor for various implementations of the KeyRing interface. A string // naming the returned interface, a function closure that cleans up any @@ -141,6 +151,8 @@ func TestKeyRingDerivation(t *testing.T) { }, } + const numKeysToDerive = 10 + // For each implementation constructor registered above, we'll execute // an identical set of tests in order to ensure that the interface // adheres to our nominal specification. @@ -163,10 +175,16 @@ func TestKeyRingDerivation(t *testing.T) { t.Fatalf("unable to derive next for "+ "keyFam=%v: %v", keyFam, err) } + assertEqualKeyLocator(t, + KeyLocator{ + Family: keyFam, + Index: 0, + }, keyDesc.KeyLocator, + ) - // If we now try to manually derive the *first* - // key, then we should get an identical public - // key back. + // We'll now re-derive that key to ensure that + // we're able to properly access the key via + // the random access derivation methods. keyLoc := KeyLocator{ Family: keyFam, Index: 0, @@ -176,13 +194,41 @@ func TestKeyRingDerivation(t *testing.T) { t.Fatalf("unable to derive first key for "+ "keyFam=%v: %v", keyFam, err) } - if !keyDesc.PubKey.IsEqual(firstKeyDesc.PubKey) { - t.Fatalf("mismatched keys: expected %v, "+ + t.Fatalf("mismatched keys: expected %x, "+ "got %x", keyDesc.PubKey.SerializeCompressed(), firstKeyDesc.PubKey.SerializeCompressed()) } + assertEqualKeyLocator(t, + KeyLocator{ + Family: keyFam, + Index: 0, + }, firstKeyDesc.KeyLocator, + ) + + // If we now try to manually derive the next 10 + // keys (including the original key), then we + // should get an identical public key back and + // their KeyLocator information + // should be set properly. + for i := 0; i < numKeysToDerive+1; i++ { + keyLoc := KeyLocator{ + Family: keyFam, + Index: uint32(i), + } + keyDesc, err := keyRing.DeriveKey(keyLoc) + if err != nil { + t.Fatalf("unable to derive first key for "+ + "keyFam=%v: %v", keyFam, err) + } + + // Ensure that the key locator matches + // up as well. + assertEqualKeyLocator( + t, keyLoc, keyDesc.KeyLocator, + ) + } // If this succeeds, then we'll also try to // derive a random index within the range. @@ -191,12 +237,15 @@ func TestKeyRingDerivation(t *testing.T) { Family: keyFam, Index: randKeyIndex, } - _, err = keyRing.DeriveKey(keyLoc) + keyDesc, err = keyRing.DeriveKey(keyLoc) if err != nil { t.Fatalf("unable to derive key_index=%v "+ "for keyFam=%v: %v", randKeyIndex, keyFam, err) } + assertEqualKeyLocator( + t, keyLoc, keyDesc.KeyLocator, + ) } }) if !success {