From 2d397b12b1dc0b010e1a4fc64b2c2d627ea70c3f Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Mon, 12 Apr 2021 15:05:01 +0200 Subject: [PATCH] amp: create amp.ShardTracker We'll use this AMP-specific ShardTracker for AMP payments. It will be used to derive hashes for each HTLC attempt using the underlying AMP derivation scheme. --- amp/shard_tracker.go | 165 ++++++++++++++++++++++++++++++++++++++ amp/shard_tracker_test.go | 95 ++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 amp/shard_tracker.go create mode 100644 amp/shard_tracker_test.go diff --git a/amp/shard_tracker.go b/amp/shard_tracker.go new file mode 100644 index 00000000..473447e4 --- /dev/null +++ b/amp/shard_tracker.go @@ -0,0 +1,165 @@ +package amp + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "sync" + + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/routing/shards" +) + +// Shard is an implementation of the shards.PaymentShards interface specific +// to AMP payments. +type Shard struct { + child *Child + mpp *record.MPP + amp *record.AMP +} + +// A compile time check to ensure Shard implements the shards.PaymentShard +// interface. +var _ shards.PaymentShard = (*Shard)(nil) + +// Hash returns the hash used for the HTLC representing this AMP shard. +func (s *Shard) Hash() lntypes.Hash { + return s.child.Hash +} + +// MPP returns any extra MPP records that should be set for the final hop on +// the route used by this shard. +func (s *Shard) MPP() *record.MPP { + return s.mpp +} + +// AMP returns any extra AMP records that should be set for the final hop on +// the route used by this shard. +func (s *Shard) AMP() *record.AMP { + return s.amp +} + +// ShardTracker is an implementation of the shards.ShardTracker interface +// that is able to generate payment shards according to the AMP splitting +// algorithm. It can be used to generate new hashes to use for HTLCs, and also +// cancel shares used for failed payment shards. +type ShardTracker struct { + setID [32]byte + paymentAddr [32]byte + totalAmt lnwire.MilliSatoshi + + sharer Sharer + + shards map[uint64]*Child + sync.Mutex +} + +// A compile time check to ensure ShardTracker implements the +// shards.ShardTracker interface. +var _ shards.ShardTracker = (*ShardTracker)(nil) + +// NewShardTracker creates a new shard tracker to use for AMP payments. The +// root shard, setID, payment address and total amount must be correctly set in +// order for the TLV options to include with each shard to be created +// correctly. +func NewShardTracker(root, setID, payAddr [32]byte, + totalAmt lnwire.MilliSatoshi) *ShardTracker { + + // Create a new seed sharer from this root. + rootShare := Share(root) + rootSharer := SeedSharerFromRoot(&rootShare) + + return &ShardTracker{ + setID: setID, + paymentAddr: payAddr, + totalAmt: totalAmt, + sharer: rootSharer, + shards: make(map[uint64]*Child), + } +} + +// NewShard registers a new attempt with the ShardTracker and returns a +// new shard representing this attempt. This attempt's shard should be canceled +// if it ends up not being used by the overall payment, i.e. if the attempt +// fails. +func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard, + error) { + + s.Lock() + defer s.Unlock() + + // Use a random child index. + var childIndex [4]byte + if _, err := rand.Read(childIndex[:]); err != nil { + return nil, err + } + idx := binary.BigEndian.Uint32(childIndex[:]) + + // Depending on whether we are requesting the last shard or not, either + // split the current share into two, or get a Child directly from the + // current sharer. + var child *Child + if last { + child = s.sharer.Child(idx) + + // If this was the last shard, set the current share to the + // zero share to indicate we cannot split it further. + s.sharer = s.sharer.Zero() + } else { + left, sharer, err := s.sharer.Split() + if err != nil { + return nil, err + } + + s.sharer = sharer + child = left.Child(idx) + } + + // Track the new child and return the shard. + s.shards[pid] = child + + mpp := record.NewMPP(s.totalAmt, s.paymentAddr) + amp := record.NewAMP( + child.ChildDesc.Share, s.setID, child.ChildDesc.Index, + ) + + return &Shard{ + child: child, + mpp: mpp, + amp: amp, + }, nil +} + +// CancelShard cancel's the shard corresponding to the given attempt ID. +func (s *ShardTracker) CancelShard(pid uint64) error { + s.Lock() + defer s.Unlock() + + c, ok := s.shards[pid] + if !ok { + return fmt.Errorf("pid not found") + } + delete(s.shards, pid) + + // Now that we are canceling this shard, we XOR the share back into our + // current share. + s.sharer = s.sharer.Merge(c) + return nil +} + +// GetHash retrieves the hash used by the shard of the given attempt ID. This +// will return an error if the attempt ID is unknown. +func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) { + s.Lock() + defer s.Unlock() + + c, ok := s.shards[pid] + if !ok { + return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+ + "not found", pid) + } + + return c.Hash, nil +} diff --git a/amp/shard_tracker_test.go b/amp/shard_tracker_test.go new file mode 100644 index 00000000..4f0ca982 --- /dev/null +++ b/amp/shard_tracker_test.go @@ -0,0 +1,95 @@ +package amp_test + +import ( + "crypto/rand" + "testing" + + "github.com/lightningnetwork/lnd/amp" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/shards" + "github.com/stretchr/testify/require" +) + +// TestAMPShardTracker tests that we can derive and cancel shards at will using +// the AMP shard tracker. +func TestAMPShardTracker(t *testing.T) { + var root, setID, payAddr [32]byte + _, err := rand.Read(root[:]) + require.NoError(t, err) + + _, err = rand.Read(setID[:]) + require.NoError(t, err) + + _, err = rand.Read(payAddr[:]) + require.NoError(t, err) + + var totalAmt lnwire.MilliSatoshi = 1000 + + // Create an AMP shard tracker using the random data we just generated. + tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt) + + // Trying to retrieve a hash for id 0 should result in an error. + _, err = tracker.GetHash(0) + require.Error(t, err) + + // We start by creating 20 shards. + const numShards = 20 + + var shards []shards.PaymentShard + for i := uint64(0); i < numShards; i++ { + s, err := tracker.NewShard(i, i == numShards-1) + require.NoError(t, err) + + // Check that the shards have their payloads set as expected. + require.Equal(t, setID, s.AMP().SetID()) + require.Equal(t, totalAmt, s.MPP().TotalMsat()) + require.Equal(t, payAddr, s.MPP().PaymentAddr()) + + shards = append(shards, s) + } + + // Make sure we can retrieve the hash for all of them. + for i := uint64(0); i < numShards; i++ { + hash, err := tracker.GetHash(i) + require.NoError(t, err) + require.Equal(t, shards[i].Hash(), hash) + } + + // Now cancel half of the shards. + j := 0 + for i := uint64(0); i < numShards; i++ { + if i%2 == 0 { + err := tracker.CancelShard(i) + require.NoError(t, err) + continue + } + + // Keep shard. + shards[j] = shards[i] + j++ + } + shards = shards[:j] + + // Get a new last shard. + s, err := tracker.NewShard(numShards, true) + require.NoError(t, err) + shards = append(shards, s) + + // Finally make sure these shards together can be used to reconstruct + // the children. + childDescs := make([]amp.ChildDesc, len(shards)) + for i, s := range shards { + childDescs[i] = amp.ChildDesc{ + Share: s.AMP().RootShare(), + Index: s.AMP().ChildIndex(), + } + } + + // Using the child descriptors, reconstruct the children. + children := amp.ReconstructChildren(childDescs...) + + // Validate that the derived child preimages match the hash of each shard. + for i, child := range children { + require.Equal(t, shards[i].Hash(), child.Hash) + } +}