diff --git a/amp/derivation_test.go b/amp/derivation_test.go new file mode 100644 index 00000000..4ed533be --- /dev/null +++ b/amp/derivation_test.go @@ -0,0 +1,113 @@ +package amp_test + +import ( + "testing" + + "github.com/lightningnetwork/lnd/amp" + "github.com/stretchr/testify/require" +) + +type sharerTest struct { + name string + numShares int +} + +var sharerTests = []sharerTest{ + { + name: "root only", + numShares: 1, + }, + { + name: "two shares", + numShares: 2, + }, + { + name: "many shares", + numShares: 10, + }, +} + +// TestSharer executes the end-to-end derivation between sender and receiver, +// asserting that shares are properly computed and, when reconstructed by the +// receiver, produce identical child hashes and preimages as the sender. +func TestSharer(t *testing.T) { + for _, test := range sharerTests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + testSharer(t, test) + }) + } +} + +func testSharer(t *testing.T, test sharerTest) { + // Construct a new sharer with a random seed. + var ( + sharer amp.Sharer + err error + ) + sharer, err = amp.NewSeedSharer() + require.NoError(t, err) + + // Assert that we can instantiate an equivalent root sharer using the + // root share. + root := sharer.Root() + sharerFromRoot := amp.SeedSharerFromRoot(&root) + require.Equal(t, sharer, sharerFromRoot) + + // Generate numShares-1 randomized shares. + children := make([]*amp.Child, 0, test.numShares) + for i := 0; i < test.numShares-1; i++ { + var left amp.Sharer + left, sharer, err = sharer.Split() + require.NoError(t, err) + + child := left.Child(0) + + assertChildShare(t, child, 0) + children = append(children, child) + } + + // Compute the final share and finalize the sharing. + child := sharer.Child(0) + + assertChildShare(t, child, 0) + children = append(children, child) + + assertReconstruction(t, children...) +} + +// assertChildShare checks that the child has the expected child index, and that +// the child's preimage is valid for the its hash. +func assertChildShare(t *testing.T, child *amp.Child, expIndex int) { + t.Helper() + + require.Equal(t, uint32(expIndex), child.Index) + require.True(t, child.Preimage.Matches(child.Hash)) +} + +// assertReconstruction takes a list of children and simulates the receiver +// recombining the shares, and then deriving the child preimage and hash for +// each HTLC. This asserts that the receiver can always rederive the full set of +// children knowing only the shares and child indexes for each. +func assertReconstruction(t *testing.T, children ...*amp.Child) { + t.Helper() + + // Reconstruct a child descriptor for each of the provided children. + // In practice, the receiver will only know the share and the child + // index it learns for each HTLC. + descs := make([]amp.ChildDesc, 0, len(children)) + for _, child := range children { + descs = append(descs, amp.ChildDesc{ + Share: child.Share, + Index: child.Index, + }) + } + + // Now, recombine the shares and rederive a child for each of the + // descriptors above. The resulting set of children should exactly match + // the set provided. + children2 := amp.ReconstructChildren(descs...) + require.Equal(t, children, children2) +}