package amp_test import ( "crypto/rand" "testing" "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/shards" "github.com/stretchr/testify/require" ) // TestAMPShardTracker tests that we can derive and cancel shards at will using // the AMP shard tracker. func TestAMPShardTracker(t *testing.T) { var root, setID, payAddr [32]byte _, err := rand.Read(root[:]) require.NoError(t, err) _, err = rand.Read(setID[:]) require.NoError(t, err) _, err = rand.Read(payAddr[:]) require.NoError(t, err) var totalAmt lnwire.MilliSatoshi = 1000 // Create an AMP shard tracker using the random data we just generated. tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt) // Trying to retrieve a hash for id 0 should result in an error. _, err = tracker.GetHash(0) require.Error(t, err) // We start by creating 20 shards. const numShards = 20 var shards []shards.PaymentShard for i := uint64(0); i < numShards; i++ { s, err := tracker.NewShard(i, i == numShards-1) require.NoError(t, err) // Check that the shards have their payloads set as expected. require.Equal(t, setID, s.AMP().SetID()) require.Equal(t, totalAmt, s.MPP().TotalMsat()) require.Equal(t, payAddr, s.MPP().PaymentAddr()) shards = append(shards, s) } // Make sure we can retrieve the hash for all of them. for i := uint64(0); i < numShards; i++ { hash, err := tracker.GetHash(i) require.NoError(t, err) require.Equal(t, shards[i].Hash(), hash) } // Now cancel half of the shards. j := 0 for i := uint64(0); i < numShards; i++ { if i%2 == 0 { err := tracker.CancelShard(i) require.NoError(t, err) continue } // Keep shard. shards[j] = shards[i] j++ } shards = shards[:j] // Get a new last shard. s, err := tracker.NewShard(numShards, true) require.NoError(t, err) shards = append(shards, s) // Finally make sure these shards together can be used to reconstruct // the children. childDescs := make([]amp.ChildDesc, len(shards)) for i, s := range shards { childDescs[i] = amp.ChildDesc{ Share: s.AMP().RootShare(), Index: s.AMP().ChildIndex(), } } // Using the child descriptors, reconstruct the children. children := amp.ReconstructChildren(childDescs...) // Validate that the derived child preimages match the hash of each shard. for i, child := range children { require.Equal(t, shards[i].Hash(), child.Hash) } }