lnd version, "hacked" to enable seedless restore from xprv + scb
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

276 lines
6.5 KiB

package shachain
import (
"reflect"
"testing"
"github.com/go-errors/errors"
)
// bitsToIndex is a helper function which takes 'n' last bits as input and
// create shachain index.
// Example:
// Input: 0,1,1,0,0
// Output: 0b000000000000000000000000000000000000000[01100] == 12
func bitsToIndex(bs ...uint64) (index, error) {
if len(bs) > 64 {
return 0, errors.New("number of elements should be lower then" +
" 64")
}
var res uint64
for i, e := range bs {
if e != 1 && e != 0 {
return 0, errors.New("wrong element, should be '0' or" +
" '1'")
}
res += e * 1 << uint(len(bs)-i-1)
}
return index(res), nil
}
type deriveTest struct {
name string
from index
to index
position []uint8
shouldFail bool
}
func generateTests(t *testing.T) []deriveTest {
var (
tests []deriveTest
from index
to index
err error
)
from, err = bitsToIndex(0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "zero 'from' 'to'",
from: from,
to: to,
position: nil,
shouldFail: false,
})
from, err = bitsToIndex(0, 1, 0, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(0, 1, 0, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "same indexes #1",
from: from,
to: to,
position: nil,
shouldFail: false,
})
from, err = bitsToIndex(1)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "same indexes #2",
from: from,
to: to,
shouldFail: true,
})
from, err = bitsToIndex(0, 0, 0, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(0, 0, 1, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "test seed 'from'",
from: from,
to: to,
position: []uint8{1},
shouldFail: false,
})
from, err = bitsToIndex(1, 1, 0, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(0, 1, 0, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "not the same indexes",
from: from,
to: to,
shouldFail: true,
})
from, err = bitsToIndex(1, 0, 1, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(1, 0, 0, 0)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "'from' index greater then 'to' index",
from: from,
to: to,
shouldFail: true,
})
from, err = bitsToIndex(1)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
to, err = bitsToIndex(1)
if err != nil {
t.Fatalf("can't generate from index: %v", err)
}
tests = append(tests, deriveTest{
name: "zero number trailing zeros",
from: from,
to: to,
position: nil,
shouldFail: false,
})
return tests
}
// TestDeriveIndex check the correctness of index derive function by testing
// the index corner cases.
func TestDeriveIndex(t *testing.T) {
t.Parallel()
for _, test := range generateTests(t) {
pos, err := test.from.deriveBitTransformations(test.to)
if err != nil {
if !test.shouldFail {
t.Fatalf("Failed (%v): %v", test.name, err)
}
} else {
if test.shouldFail {
t.Fatalf("Failed (%v): test should failed "+
"but it's not", test.name)
}
if !reflect.DeepEqual(pos, test.position) {
t.Fatalf("Failed(%v): position is wrong real:"+
"%v expected:%v", test.name, pos, test.position)
}
}
t.Logf("Passed: %v", test.name)
}
}
// deriveElementTests encodes the test vectors specified in BOLT-03,
// Appendix D, Generation Tests.
var deriveElementTests = []struct {
name string
index index
output string
seed string
shouldFail bool
}{
{
name: "generate_from_seed 0 final node",
seed: "0000000000000000000000000000000000000000000000000000000000000000",
index: 0xffffffffffff,
output: "02a40c85b6f28da08dfdbe0926c53fab2de6d28c10301f8f7c4073d5e42e3148",
shouldFail: false,
},
{
name: "generate_from_seed FF final node",
seed: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
index: 0xffffffffffff,
output: "7cc854b54e3e0dcdb010d7a3fee464a9687be6e8db3be6854c475621e007a5dc",
shouldFail: false,
},
{
name: "generate_from_seed FF alternate bits 1",
index: 0xaaaaaaaaaaa,
output: "56f4008fb007ca9acf0e15b054d5c9fd12ee06cea347914ddbaed70d1c13a528",
seed: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
shouldFail: false,
},
{
name: "generate_from_seed FF alternate bits 2",
index: 0x555555555555,
output: "9015daaeb06dba4ccc05b91b2f73bd54405f2be9f217fbacd3c5ac2e62327d31",
seed: "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF",
shouldFail: false,
},
{
name: "generate_from_seed 01 last nontrivial node",
index: 1,
output: "915c75942a26bb3a433a8ce2cb0427c29ec6c1775cfc78328b57f6ba7bfeaa9c",
seed: "0101010101010101010101010101010101010101010101010101010101010101",
shouldFail: false,
},
}
// TestSpecificationDeriveElement is used to check the consistency with
// specification hash derivation function.
func TestSpecificationDeriveElement(t *testing.T) {
t.Parallel()
for _, test := range deriveElementTests {
// Generate seed element.
element, err := newElementFromStr(test.seed, rootIndex)
if err != nil {
t.Fatal(err)
}
// Derive element by index.
result, err := element.derive(test.index)
if err != nil {
if !test.shouldFail {
t.Fatalf("Failed (%v): %v", test.name, err)
}
} else {
if test.shouldFail {
t.Fatalf("Failed (%v): test should failed "+
"but it's not", test.name)
}
// Generate element which we should get after derivation.
output, err := newElementFromStr(test.output, test.index)
if err != nil {
t.Fatal(err)
}
// Check that they are equal.
if !result.isEqual(output) {
t.Fatalf("Failed (%v): hash is wrong, real:"+
"%v expected:%v", test.name,
result.hash.String(), output.hash.String())
}
}
t.Logf("Passed (%v)", test.name)
}
}