2d397b12b1
We'll use this AMP-specific ShardTracker for AMP payments. It will be used to derive hashes for each HTLC attempt using the underlying AMP derivation scheme.
96 lines
2.4 KiB
Go
96 lines
2.4 KiB
Go
package amp_test
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"testing"
|
|
|
|
"github.com/lightningnetwork/lnd/amp"
|
|
"github.com/lightningnetwork/lnd/lnwire"
|
|
"github.com/lightningnetwork/lnd/routing/shards"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestAMPShardTracker tests that we can derive and cancel shards at will using
|
|
// the AMP shard tracker.
|
|
func TestAMPShardTracker(t *testing.T) {
|
|
var root, setID, payAddr [32]byte
|
|
_, err := rand.Read(root[:])
|
|
require.NoError(t, err)
|
|
|
|
_, err = rand.Read(setID[:])
|
|
require.NoError(t, err)
|
|
|
|
_, err = rand.Read(payAddr[:])
|
|
require.NoError(t, err)
|
|
|
|
var totalAmt lnwire.MilliSatoshi = 1000
|
|
|
|
// Create an AMP shard tracker using the random data we just generated.
|
|
tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt)
|
|
|
|
// Trying to retrieve a hash for id 0 should result in an error.
|
|
_, err = tracker.GetHash(0)
|
|
require.Error(t, err)
|
|
|
|
// We start by creating 20 shards.
|
|
const numShards = 20
|
|
|
|
var shards []shards.PaymentShard
|
|
for i := uint64(0); i < numShards; i++ {
|
|
s, err := tracker.NewShard(i, i == numShards-1)
|
|
require.NoError(t, err)
|
|
|
|
// Check that the shards have their payloads set as expected.
|
|
require.Equal(t, setID, s.AMP().SetID())
|
|
require.Equal(t, totalAmt, s.MPP().TotalMsat())
|
|
require.Equal(t, payAddr, s.MPP().PaymentAddr())
|
|
|
|
shards = append(shards, s)
|
|
}
|
|
|
|
// Make sure we can retrieve the hash for all of them.
|
|
for i := uint64(0); i < numShards; i++ {
|
|
hash, err := tracker.GetHash(i)
|
|
require.NoError(t, err)
|
|
require.Equal(t, shards[i].Hash(), hash)
|
|
}
|
|
|
|
// Now cancel half of the shards.
|
|
j := 0
|
|
for i := uint64(0); i < numShards; i++ {
|
|
if i%2 == 0 {
|
|
err := tracker.CancelShard(i)
|
|
require.NoError(t, err)
|
|
continue
|
|
}
|
|
|
|
// Keep shard.
|
|
shards[j] = shards[i]
|
|
j++
|
|
}
|
|
shards = shards[:j]
|
|
|
|
// Get a new last shard.
|
|
s, err := tracker.NewShard(numShards, true)
|
|
require.NoError(t, err)
|
|
shards = append(shards, s)
|
|
|
|
// Finally make sure these shards together can be used to reconstruct
|
|
// the children.
|
|
childDescs := make([]amp.ChildDesc, len(shards))
|
|
for i, s := range shards {
|
|
childDescs[i] = amp.ChildDesc{
|
|
Share: s.AMP().RootShare(),
|
|
Index: s.AMP().ChildIndex(),
|
|
}
|
|
}
|
|
|
|
// Using the child descriptors, reconstruct the children.
|
|
children := amp.ReconstructChildren(childDescs...)
|
|
|
|
// Validate that the derived child preimages match the hash of each shard.
|
|
for i, child := range children {
|
|
require.Equal(t, shards[i].Hash(), child.Hash)
|
|
}
|
|
}
|