routing/shards: add ShardTracker interface
We'll use this to keep track of the outstanding shards and which preimages we are using for each. For now this is a simple map from attempt ID to hash, but later we'll hide the AMP child derivation behind this interface.
This commit is contained in:
parent
a9f19b100b
commit
6474b253d6
126
routing/shards/shard_tracker.go
Normal file
126
routing/shards/shard_tracker.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package shards
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/record"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PaymentShard is an interface representing a shard tracked by the
|
||||||
|
// ShardTracker. It contains options that are specific to the given shard that
|
||||||
|
// might differ from the overall payment.
|
||||||
|
type PaymentShard interface {
|
||||||
|
// Hash returns the hash used for the HTLC representing this shard.
|
||||||
|
Hash() lntypes.Hash
|
||||||
|
|
||||||
|
// MPP returns any extra MPP records that should be set for the final
|
||||||
|
// hop on the route used by this shard.
|
||||||
|
MPP() *record.MPP
|
||||||
|
|
||||||
|
// AMP returns any extra AMP records that should be set for the final
|
||||||
|
// hop on the route used by this shard.
|
||||||
|
AMP() *record.AMP
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShardTracker is an interfae representing a tracker that keeps track of the
|
||||||
|
// inflight shards of a payment, and is able to assign new shards the correct
|
||||||
|
// options such as hash and extra records.
|
||||||
|
type ShardTracker interface {
|
||||||
|
// NewShard registers a new attempt with the ShardTracker and returns a
|
||||||
|
// new shard representing this attempt. This attempt's shard should be
|
||||||
|
// canceled if it ends up not being used by the overall payment, i.e.
|
||||||
|
// if the attempt fails.
|
||||||
|
NewShard(uint64, bool) (PaymentShard, error)
|
||||||
|
|
||||||
|
// CancelShard cancel's the shard corresponding to the given attempt
|
||||||
|
// ID. This lets the ShardTracker free up any slots used by this shard,
|
||||||
|
// and in case of AMP payments return the share used by this shard to
|
||||||
|
// the root share.
|
||||||
|
CancelShard(uint64) error
|
||||||
|
|
||||||
|
// GetHash retrieves the hash used by the shard of the given attempt
|
||||||
|
// ID. This wil return an error if the attempt ID is unknown.
|
||||||
|
GetHash(uint64) (lntypes.Hash, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shard is a struct used for simple shards where we obly need to keep map it
|
||||||
|
// to a single hash.
|
||||||
|
type Shard struct {
|
||||||
|
hash lntypes.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash returns the hash used for the HTLC representing this shard.
|
||||||
|
func (s *Shard) Hash() lntypes.Hash {
|
||||||
|
return s.hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// MPP returns any extra MPP records that should be set for the final hop on
|
||||||
|
// the route used by this shard.
|
||||||
|
func (s *Shard) MPP() *record.MPP {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AMP returns any extra AMP records that should be set for the final hop on
|
||||||
|
// the route used by this shard.
|
||||||
|
func (s *Shard) AMP() *record.AMP {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SimpleShardTracker is an implementation of the ShardTracker interface that
|
||||||
|
// simply maps attempt IDs to hashes. New shards will be given a static payment
|
||||||
|
// hash. This should be used for regular and MPP payments, in addition to
|
||||||
|
// resumed payments where all the attempt's hashes have already been created.
|
||||||
|
type SimpleShardTracker struct {
|
||||||
|
hash lntypes.Hash
|
||||||
|
shards map[uint64]lntypes.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// A compile time check to ensure SimpleShardTracker implements the
|
||||||
|
// ShardTracker interface.
|
||||||
|
var _ ShardTracker = (*SimpleShardTracker)(nil)
|
||||||
|
|
||||||
|
// NewSimpleShardTracker creates a new intance of the SimpleShardTracker with
|
||||||
|
// the given payment hash and existing attempts.
|
||||||
|
func NewSimpleShardTracker(paymentHash lntypes.Hash,
|
||||||
|
shards map[uint64]lntypes.Hash) ShardTracker {
|
||||||
|
|
||||||
|
if shards == nil {
|
||||||
|
shards = make(map[uint64]lntypes.Hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SimpleShardTracker{
|
||||||
|
hash: paymentHash,
|
||||||
|
shards: shards,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewShard registers a new attempt with the ShardTracker and returns a
|
||||||
|
// new shard representing this attempt. This attempt's shard should be canceled
|
||||||
|
// if it ends up not being used by the overall payment, i.e. if the attempt
|
||||||
|
// fails.
|
||||||
|
func (m *SimpleShardTracker) NewShard(id uint64, _ bool) (PaymentShard, error) {
|
||||||
|
m.shards[id] = m.hash
|
||||||
|
|
||||||
|
return &Shard{
|
||||||
|
hash: m.hash,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelShard cancel's the shard corresponding to the given attempt ID.
|
||||||
|
func (m *SimpleShardTracker) CancelShard(id uint64) error {
|
||||||
|
delete(m.shards, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHash retrieves the hash used by the shard of the given attempt ID. This
|
||||||
|
// will return an error if the attempt ID is unknown.
|
||||||
|
func (m *SimpleShardTracker) GetHash(id uint64) (lntypes.Hash, error) {
|
||||||
|
hash, ok := m.shards[id]
|
||||||
|
if !ok {
|
||||||
|
return lntypes.Hash{}, fmt.Errorf("hash for attempt id %v "+
|
||||||
|
"not found", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hash, nil
|
||||||
|
}
|
47
routing/shards/shard_tracker_test.go
Normal file
47
routing/shards/shard_tracker_test.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package shards_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/routing/shards"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestSimpleShardTracker tests that the simple tracker that keeps a map from
|
||||||
|
// attemptID-> payment hash works.
|
||||||
|
func TestSimpleShardTracker(t *testing.T) {
|
||||||
|
var testHashes [2]lntypes.Hash
|
||||||
|
for i := range testHashes {
|
||||||
|
_, err := rand.Read(testHashes[i][:])
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
startAttempts := map[uint64]lntypes.Hash{
|
||||||
|
1: testHashes[1],
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker := shards.NewSimpleShardTracker(testHashes[0], startAttempts)
|
||||||
|
|
||||||
|
// Trying to retrieve a hash for id 0 should result in an error.
|
||||||
|
_, err := tracker.GetHash(0)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Getting id 1 should workd.
|
||||||
|
hash1, err := tracker.GetHash(1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, testHashes[1], hash1)
|
||||||
|
|
||||||
|
// Finally, create a new shard and immediately retrieve the hash.
|
||||||
|
shard, err := tracker.NewShard(2, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// It's hash should be the tracker's overall payment hash.
|
||||||
|
hash2, err := tracker.GetHash(2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, testHashes[0], shard.Hash())
|
||||||
|
require.Equal(t, testHashes[0], hash2)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user