amp: create amp.ShardTracker
We'll use this AMP-specific ShardTracker for AMP payments. It will be used to derive hashes for each HTLC attempt using the underlying AMP derivation scheme.
This commit is contained in:
parent
c1e82e534d
commit
2d397b12b1
165
amp/shard_tracker.go
Normal file
165
amp/shard_tracker.go
Normal file
@ -0,0 +1,165 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
"github.com/lightningnetwork/lnd/routing/shards"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Shard is an implementation of the shards.PaymentShards interface specific
|
||||||
|
// to AMP payments.
|
||||||
|
type Shard struct {
|
||||||
|
child *Child
|
||||||
|
mpp *record.MPP
|
||||||
|
amp *record.AMP
|
||||||
|
}
|
||||||
|
|
||||||
|
// A compile time check to ensure Shard implements the shards.PaymentShard
|
||||||
|
// interface.
|
||||||
|
var _ shards.PaymentShard = (*Shard)(nil)
|
||||||
|
|
||||||
|
// Hash returns the hash used for the HTLC representing this AMP shard.
|
||||||
|
func (s *Shard) Hash() lntypes.Hash {
|
||||||
|
return s.child.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// MPP returns any extra MPP records that should be set for the final hop on
|
||||||
|
// the route used by this shard.
|
||||||
|
func (s *Shard) MPP() *record.MPP {
|
||||||
|
return s.mpp
|
||||||
|
}
|
||||||
|
|
||||||
|
// AMP returns any extra AMP records that should be set for the final hop on
|
||||||
|
// the route used by this shard.
|
||||||
|
func (s *Shard) AMP() *record.AMP {
|
||||||
|
return s.amp
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShardTracker is an implementation of the shards.ShardTracker interface
|
||||||
|
// that is able to generate payment shards according to the AMP splitting
|
||||||
|
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
|
||||||
|
// cancel shares used for failed payment shards.
|
||||||
|
type ShardTracker struct {
|
||||||
|
setID [32]byte
|
||||||
|
paymentAddr [32]byte
|
||||||
|
totalAmt lnwire.MilliSatoshi
|
||||||
|
|
||||||
|
sharer Sharer
|
||||||
|
|
||||||
|
shards map[uint64]*Child
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// A compile time check to ensure ShardTracker implements the
|
||||||
|
// shards.ShardTracker interface.
|
||||||
|
var _ shards.ShardTracker = (*ShardTracker)(nil)
|
||||||
|
|
||||||
|
// NewShardTracker creates a new shard tracker to use for AMP payments. The
|
||||||
|
// root shard, setID, payment address and total amount must be correctly set in
|
||||||
|
// order for the TLV options to include with each shard to be created
|
||||||
|
// correctly.
|
||||||
|
func NewShardTracker(root, setID, payAddr [32]byte,
|
||||||
|
totalAmt lnwire.MilliSatoshi) *ShardTracker {
|
||||||
|
|
||||||
|
// Create a new seed sharer from this root.
|
||||||
|
rootShare := Share(root)
|
||||||
|
rootSharer := SeedSharerFromRoot(&rootShare)
|
||||||
|
|
||||||
|
return &ShardTracker{
|
||||||
|
setID: setID,
|
||||||
|
paymentAddr: payAddr,
|
||||||
|
totalAmt: totalAmt,
|
||||||
|
sharer: rootSharer,
|
||||||
|
shards: make(map[uint64]*Child),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewShard registers a new attempt with the ShardTracker and returns a
|
||||||
|
// new shard representing this attempt. This attempt's shard should be canceled
|
||||||
|
// if it ends up not being used by the overall payment, i.e. if the attempt
|
||||||
|
// fails.
|
||||||
|
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard,
|
||||||
|
error) {
|
||||||
|
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
// Use a random child index.
|
||||||
|
var childIndex [4]byte
|
||||||
|
if _, err := rand.Read(childIndex[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
idx := binary.BigEndian.Uint32(childIndex[:])
|
||||||
|
|
||||||
|
// Depending on whether we are requesting the last shard or not, either
|
||||||
|
// split the current share into two, or get a Child directly from the
|
||||||
|
// current sharer.
|
||||||
|
var child *Child
|
||||||
|
if last {
|
||||||
|
child = s.sharer.Child(idx)
|
||||||
|
|
||||||
|
// If this was the last shard, set the current share to the
|
||||||
|
// zero share to indicate we cannot split it further.
|
||||||
|
s.sharer = s.sharer.Zero()
|
||||||
|
} else {
|
||||||
|
left, sharer, err := s.sharer.Split()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sharer = sharer
|
||||||
|
child = left.Child(idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track the new child and return the shard.
|
||||||
|
s.shards[pid] = child
|
||||||
|
|
||||||
|
mpp := record.NewMPP(s.totalAmt, s.paymentAddr)
|
||||||
|
amp := record.NewAMP(
|
||||||
|
child.ChildDesc.Share, s.setID, child.ChildDesc.Index,
|
||||||
|
)
|
||||||
|
|
||||||
|
return &Shard{
|
||||||
|
child: child,
|
||||||
|
mpp: mpp,
|
||||||
|
amp: amp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelShard cancel's the shard corresponding to the given attempt ID.
|
||||||
|
func (s *ShardTracker) CancelShard(pid uint64) error {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
c, ok := s.shards[pid]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("pid not found")
|
||||||
|
}
|
||||||
|
delete(s.shards, pid)
|
||||||
|
|
||||||
|
// Now that we are canceling this shard, we XOR the share back into our
|
||||||
|
// current share.
|
||||||
|
s.sharer = s.sharer.Merge(c)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHash retrieves the hash used by the shard of the given attempt ID. This
|
||||||
|
// will return an error if the attempt ID is unknown.
|
||||||
|
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) {
|
||||||
|
s.Lock()
|
||||||
|
defer s.Unlock()
|
||||||
|
|
||||||
|
c, ok := s.shards[pid]
|
||||||
|
if !ok {
|
||||||
|
return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+
|
||||||
|
"not found", pid)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Hash, nil
|
||||||
|
}
|
95
amp/shard_tracker_test.go
Normal file
95
amp/shard_tracker_test.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
package amp_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/amp"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
|
"github.com/lightningnetwork/lnd/routing/shards"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAMPShardTracker tests that we can derive and cancel shards at will using
|
||||||
|
// the AMP shard tracker.
|
||||||
|
func TestAMPShardTracker(t *testing.T) {
|
||||||
|
var root, setID, payAddr [32]byte
|
||||||
|
_, err := rand.Read(root[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = rand.Read(setID[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = rand.Read(payAddr[:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var totalAmt lnwire.MilliSatoshi = 1000
|
||||||
|
|
||||||
|
// Create an AMP shard tracker using the random data we just generated.
|
||||||
|
tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt)
|
||||||
|
|
||||||
|
// Trying to retrieve a hash for id 0 should result in an error.
|
||||||
|
_, err = tracker.GetHash(0)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// We start by creating 20 shards.
|
||||||
|
const numShards = 20
|
||||||
|
|
||||||
|
var shards []shards.PaymentShard
|
||||||
|
for i := uint64(0); i < numShards; i++ {
|
||||||
|
s, err := tracker.NewShard(i, i == numShards-1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that the shards have their payloads set as expected.
|
||||||
|
require.Equal(t, setID, s.AMP().SetID())
|
||||||
|
require.Equal(t, totalAmt, s.MPP().TotalMsat())
|
||||||
|
require.Equal(t, payAddr, s.MPP().PaymentAddr())
|
||||||
|
|
||||||
|
shards = append(shards, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we can retrieve the hash for all of them.
|
||||||
|
for i := uint64(0); i < numShards; i++ {
|
||||||
|
hash, err := tracker.GetHash(i)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, shards[i].Hash(), hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now cancel half of the shards.
|
||||||
|
j := 0
|
||||||
|
for i := uint64(0); i < numShards; i++ {
|
||||||
|
if i%2 == 0 {
|
||||||
|
err := tracker.CancelShard(i)
|
||||||
|
require.NoError(t, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep shard.
|
||||||
|
shards[j] = shards[i]
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
shards = shards[:j]
|
||||||
|
|
||||||
|
// Get a new last shard.
|
||||||
|
s, err := tracker.NewShard(numShards, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
shards = append(shards, s)
|
||||||
|
|
||||||
|
// Finally make sure these shards together can be used to reconstruct
|
||||||
|
// the children.
|
||||||
|
childDescs := make([]amp.ChildDesc, len(shards))
|
||||||
|
for i, s := range shards {
|
||||||
|
childDescs[i] = amp.ChildDesc{
|
||||||
|
Share: s.AMP().RootShare(),
|
||||||
|
Index: s.AMP().ChildIndex(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Using the child descriptors, reconstruct the children.
|
||||||
|
children := amp.ReconstructChildren(childDescs...)
|
||||||
|
|
||||||
|
// Validate that the derived child preimages match the hash of each shard.
|
||||||
|
for i, child := range children {
|
||||||
|
require.Equal(t, shards[i].Hash(), child.Hash)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user