Browse Source
We'll use this AMP-specific ShardTracker for AMP payments. It will be used to derive hashes for each HTLC attempt using the underlying AMP derivation scheme.master
Johan T. Halseth
3 years ago
2 changed files with 260 additions and 0 deletions
@ -0,0 +1,165 @@
|
||||
package amp |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"encoding/binary" |
||||
"fmt" |
||||
"sync" |
||||
|
||||
"github.com/lightningnetwork/lnd/lntypes" |
||||
"github.com/lightningnetwork/lnd/lnwire" |
||||
"github.com/lightningnetwork/lnd/record" |
||||
"github.com/lightningnetwork/lnd/routing/shards" |
||||
) |
||||
|
||||
// Shard is an implementation of the shards.PaymentShards interface specific
|
||||
// to AMP payments.
|
||||
type Shard struct { |
||||
child *Child |
||||
mpp *record.MPP |
||||
amp *record.AMP |
||||
} |
||||
|
||||
// A compile time check to ensure Shard implements the shards.PaymentShard
|
||||
// interface.
|
||||
var _ shards.PaymentShard = (*Shard)(nil) |
||||
|
||||
// Hash returns the hash used for the HTLC representing this AMP shard.
|
||||
func (s *Shard) Hash() lntypes.Hash { |
||||
return s.child.Hash |
||||
} |
||||
|
||||
// MPP returns any extra MPP records that should be set for the final hop on
|
||||
// the route used by this shard.
|
||||
func (s *Shard) MPP() *record.MPP { |
||||
return s.mpp |
||||
} |
||||
|
||||
// AMP returns any extra AMP records that should be set for the final hop on
|
||||
// the route used by this shard.
|
||||
func (s *Shard) AMP() *record.AMP { |
||||
return s.amp |
||||
} |
||||
|
||||
// ShardTracker is an implementation of the shards.ShardTracker interface
|
||||
// that is able to generate payment shards according to the AMP splitting
|
||||
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
|
||||
// cancel shares used for failed payment shards.
|
||||
type ShardTracker struct { |
||||
setID [32]byte |
||||
paymentAddr [32]byte |
||||
totalAmt lnwire.MilliSatoshi |
||||
|
||||
sharer Sharer |
||||
|
||||
shards map[uint64]*Child |
||||
sync.Mutex |
||||
} |
||||
|
||||
// A compile time check to ensure ShardTracker implements the
|
||||
// shards.ShardTracker interface.
|
||||
var _ shards.ShardTracker = (*ShardTracker)(nil) |
||||
|
||||
// NewShardTracker creates a new shard tracker to use for AMP payments. The
|
||||
// root shard, setID, payment address and total amount must be correctly set in
|
||||
// order for the TLV options to include with each shard to be created
|
||||
// correctly.
|
||||
func NewShardTracker(root, setID, payAddr [32]byte, |
||||
totalAmt lnwire.MilliSatoshi) *ShardTracker { |
||||
|
||||
// Create a new seed sharer from this root.
|
||||
rootShare := Share(root) |
||||
rootSharer := SeedSharerFromRoot(&rootShare) |
||||
|
||||
return &ShardTracker{ |
||||
setID: setID, |
||||
paymentAddr: payAddr, |
||||
totalAmt: totalAmt, |
||||
sharer: rootSharer, |
||||
shards: make(map[uint64]*Child), |
||||
} |
||||
} |
||||
|
||||
// NewShard registers a new attempt with the ShardTracker and returns a
|
||||
// new shard representing this attempt. This attempt's shard should be canceled
|
||||
// if it ends up not being used by the overall payment, i.e. if the attempt
|
||||
// fails.
|
||||
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard, |
||||
error) { |
||||
|
||||
s.Lock() |
||||
defer s.Unlock() |
||||
|
||||
// Use a random child index.
|
||||
var childIndex [4]byte |
||||
if _, err := rand.Read(childIndex[:]); err != nil { |
||||
return nil, err |
||||
} |
||||
idx := binary.BigEndian.Uint32(childIndex[:]) |
||||
|
||||
// Depending on whether we are requesting the last shard or not, either
|
||||
// split the current share into two, or get a Child directly from the
|
||||
// current sharer.
|
||||
var child *Child |
||||
if last { |
||||
child = s.sharer.Child(idx) |
||||
|
||||
// If this was the last shard, set the current share to the
|
||||
// zero share to indicate we cannot split it further.
|
||||
s.sharer = s.sharer.Zero() |
||||
} else { |
||||
left, sharer, err := s.sharer.Split() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
s.sharer = sharer |
||||
child = left.Child(idx) |
||||
} |
||||
|
||||
// Track the new child and return the shard.
|
||||
s.shards[pid] = child |
||||
|
||||
mpp := record.NewMPP(s.totalAmt, s.paymentAddr) |
||||
amp := record.NewAMP( |
||||
child.ChildDesc.Share, s.setID, child.ChildDesc.Index, |
||||
) |
||||
|
||||
return &Shard{ |
||||
child: child, |
||||
mpp: mpp, |
||||
amp: amp, |
||||
}, nil |
||||
} |
||||
|
||||
// CancelShard cancel's the shard corresponding to the given attempt ID.
|
||||
func (s *ShardTracker) CancelShard(pid uint64) error { |
||||
s.Lock() |
||||
defer s.Unlock() |
||||
|
||||
c, ok := s.shards[pid] |
||||
if !ok { |
||||
return fmt.Errorf("pid not found") |
||||
} |
||||
delete(s.shards, pid) |
||||
|
||||
// Now that we are canceling this shard, we XOR the share back into our
|
||||
// current share.
|
||||
s.sharer = s.sharer.Merge(c) |
||||
return nil |
||||
} |
||||
|
||||
// GetHash retrieves the hash used by the shard of the given attempt ID. This
|
||||
// will return an error if the attempt ID is unknown.
|
||||
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) { |
||||
s.Lock() |
||||
defer s.Unlock() |
||||
|
||||
c, ok := s.shards[pid] |
||||
if !ok { |
||||
return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+ |
||||
"not found", pid) |
||||
} |
||||
|
||||
return c.Hash, nil |
||||
} |
@ -0,0 +1,95 @@
|
||||
package amp_test |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"testing" |
||||
|
||||
"github.com/lightningnetwork/lnd/amp" |
||||
"github.com/lightningnetwork/lnd/lnwire" |
||||
"github.com/lightningnetwork/lnd/routing/shards" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
// TestAMPShardTracker tests that we can derive and cancel shards at will using
|
||||
// the AMP shard tracker.
|
||||
func TestAMPShardTracker(t *testing.T) { |
||||
var root, setID, payAddr [32]byte |
||||
_, err := rand.Read(root[:]) |
||||
require.NoError(t, err) |
||||
|
||||
_, err = rand.Read(setID[:]) |
||||
require.NoError(t, err) |
||||
|
||||
_, err = rand.Read(payAddr[:]) |
||||
require.NoError(t, err) |
||||
|
||||
var totalAmt lnwire.MilliSatoshi = 1000 |
||||
|
||||
// Create an AMP shard tracker using the random data we just generated.
|
||||
tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt) |
||||
|
||||
// Trying to retrieve a hash for id 0 should result in an error.
|
||||
_, err = tracker.GetHash(0) |
||||
require.Error(t, err) |
||||
|
||||
// We start by creating 20 shards.
|
||||
const numShards = 20 |
||||
|
||||
var shards []shards.PaymentShard |
||||
for i := uint64(0); i < numShards; i++ { |
||||
s, err := tracker.NewShard(i, i == numShards-1) |
||||
require.NoError(t, err) |
||||
|
||||
// Check that the shards have their payloads set as expected.
|
||||
require.Equal(t, setID, s.AMP().SetID()) |
||||
require.Equal(t, totalAmt, s.MPP().TotalMsat()) |
||||
require.Equal(t, payAddr, s.MPP().PaymentAddr()) |
||||
|
||||
shards = append(shards, s) |
||||
} |
||||
|
||||
// Make sure we can retrieve the hash for all of them.
|
||||
for i := uint64(0); i < numShards; i++ { |
||||
hash, err := tracker.GetHash(i) |
||||
require.NoError(t, err) |
||||
require.Equal(t, shards[i].Hash(), hash) |
||||
} |
||||
|
||||
// Now cancel half of the shards.
|
||||
j := 0 |
||||
for i := uint64(0); i < numShards; i++ { |
||||
if i%2 == 0 { |
||||
err := tracker.CancelShard(i) |
||||
require.NoError(t, err) |
||||
continue |
||||
} |
||||
|
||||
// Keep shard.
|
||||
shards[j] = shards[i] |
||||
j++ |
||||
} |
||||
shards = shards[:j] |
||||
|
||||
// Get a new last shard.
|
||||
s, err := tracker.NewShard(numShards, true) |
||||
require.NoError(t, err) |
||||
shards = append(shards, s) |
||||
|
||||
// Finally make sure these shards together can be used to reconstruct
|
||||
// the children.
|
||||
childDescs := make([]amp.ChildDesc, len(shards)) |
||||
for i, s := range shards { |
||||
childDescs[i] = amp.ChildDesc{ |
||||
Share: s.AMP().RootShare(), |
||||
Index: s.AMP().ChildIndex(), |
||||
} |
||||
} |
||||
|
||||
// Using the child descriptors, reconstruct the children.
|
||||
children := amp.ReconstructChildren(childDescs...) |
||||
|
||||
// Validate that the derived child preimages match the hash of each shard.
|
||||
for i, child := range children { |
||||
require.Equal(t, shards[i].Hash(), child.Hash) |
||||
} |
||||
} |
Loading…
Reference in new issue