You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
165 lines
4.2 KiB
165 lines
4.2 KiB
package amp |
|
|
|
import ( |
|
"crypto/rand" |
|
"encoding/binary" |
|
"fmt" |
|
"sync" |
|
|
|
"github.com/lightningnetwork/lnd/lntypes" |
|
"github.com/lightningnetwork/lnd/lnwire" |
|
"github.com/lightningnetwork/lnd/record" |
|
"github.com/lightningnetwork/lnd/routing/shards" |
|
) |
|
|
|
// Shard is an implementation of the shards.PaymentShards interface specific |
|
// to AMP payments. |
|
type Shard struct { |
|
child *Child |
|
mpp *record.MPP |
|
amp *record.AMP |
|
} |
|
|
|
// A compile time check to ensure Shard implements the shards.PaymentShard |
|
// interface. |
|
var _ shards.PaymentShard = (*Shard)(nil) |
|
|
|
// Hash returns the hash used for the HTLC representing this AMP shard. |
|
func (s *Shard) Hash() lntypes.Hash { |
|
return s.child.Hash |
|
} |
|
|
|
// MPP returns any extra MPP records that should be set for the final hop on |
|
// the route used by this shard. |
|
func (s *Shard) MPP() *record.MPP { |
|
return s.mpp |
|
} |
|
|
|
// AMP returns any extra AMP records that should be set for the final hop on |
|
// the route used by this shard. |
|
func (s *Shard) AMP() *record.AMP { |
|
return s.amp |
|
} |
|
|
|
// ShardTracker is an implementation of the shards.ShardTracker interface |
|
// that is able to generate payment shards according to the AMP splitting |
|
// algorithm. It can be used to generate new hashes to use for HTLCs, and also |
|
// cancel shares used for failed payment shards. |
|
type ShardTracker struct { |
|
setID [32]byte |
|
paymentAddr [32]byte |
|
totalAmt lnwire.MilliSatoshi |
|
|
|
sharer Sharer |
|
|
|
shards map[uint64]*Child |
|
sync.Mutex |
|
} |
|
|
|
// A compile time check to ensure ShardTracker implements the |
|
// shards.ShardTracker interface. |
|
var _ shards.ShardTracker = (*ShardTracker)(nil) |
|
|
|
// NewShardTracker creates a new shard tracker to use for AMP payments. The |
|
// root shard, setID, payment address and total amount must be correctly set in |
|
// order for the TLV options to include with each shard to be created |
|
// correctly. |
|
func NewShardTracker(root, setID, payAddr [32]byte, |
|
totalAmt lnwire.MilliSatoshi) *ShardTracker { |
|
|
|
// Create a new seed sharer from this root. |
|
rootShare := Share(root) |
|
rootSharer := SeedSharerFromRoot(&rootShare) |
|
|
|
return &ShardTracker{ |
|
setID: setID, |
|
paymentAddr: payAddr, |
|
totalAmt: totalAmt, |
|
sharer: rootSharer, |
|
shards: make(map[uint64]*Child), |
|
} |
|
} |
|
|
|
// NewShard registers a new attempt with the ShardTracker and returns a |
|
// new shard representing this attempt. This attempt's shard should be canceled |
|
// if it ends up not being used by the overall payment, i.e. if the attempt |
|
// fails. |
|
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard, |
|
error) { |
|
|
|
s.Lock() |
|
defer s.Unlock() |
|
|
|
// Use a random child index. |
|
var childIndex [4]byte |
|
if _, err := rand.Read(childIndex[:]); err != nil { |
|
return nil, err |
|
} |
|
idx := binary.BigEndian.Uint32(childIndex[:]) |
|
|
|
// Depending on whether we are requesting the last shard or not, either |
|
// split the current share into two, or get a Child directly from the |
|
// current sharer. |
|
var child *Child |
|
if last { |
|
child = s.sharer.Child(idx) |
|
|
|
// If this was the last shard, set the current share to the |
|
// zero share to indicate we cannot split it further. |
|
s.sharer = s.sharer.Zero() |
|
} else { |
|
left, sharer, err := s.sharer.Split() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
s.sharer = sharer |
|
child = left.Child(idx) |
|
} |
|
|
|
// Track the new child and return the shard. |
|
s.shards[pid] = child |
|
|
|
mpp := record.NewMPP(s.totalAmt, s.paymentAddr) |
|
amp := record.NewAMP( |
|
child.ChildDesc.Share, s.setID, child.ChildDesc.Index, |
|
) |
|
|
|
return &Shard{ |
|
child: child, |
|
mpp: mpp, |
|
amp: amp, |
|
}, nil |
|
} |
|
|
|
// CancelShard cancel's the shard corresponding to the given attempt ID. |
|
func (s *ShardTracker) CancelShard(pid uint64) error { |
|
s.Lock() |
|
defer s.Unlock() |
|
|
|
c, ok := s.shards[pid] |
|
if !ok { |
|
return fmt.Errorf("pid not found") |
|
} |
|
delete(s.shards, pid) |
|
|
|
// Now that we are canceling this shard, we XOR the share back into our |
|
// current share. |
|
s.sharer = s.sharer.Merge(c) |
|
return nil |
|
} |
|
|
|
// GetHash retrieves the hash used by the shard of the given attempt ID. This |
|
// will return an error if the attempt ID is unknown. |
|
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) { |
|
s.Lock() |
|
defer s.Unlock() |
|
|
|
c, ok := s.shards[pid] |
|
if !ok { |
|
return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+ |
|
"not found", pid) |
|
} |
|
|
|
return c.Hash, nil |
|
}
|
|
|