htlcswitch: obtain the min final CLTV delta directly from the InvoiceDatabase

In this commit, we modify the existing logic that would attempt to read
the min CLTV information from the invoice directly. With this route, we
avoid any sort of DB index modifications, as this information is already
stored within the payment request, which is already available to the
outside callers. By modifying the InvoiceDatabase interface, we avoid
having to make the switch aware of what the "primary" chain is.
This commit is contained in:
Olaoluwa Osuntokun 2018-06-29 16:02:59 -07:00
parent c914de177b
commit 2196d9375e
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
6 changed files with 72 additions and 54 deletions

@ -11,8 +11,10 @@ import (
// which may search, lookup and settle invoices.
type InvoiceDatabase interface {
// LookupInvoice attempts to look up an invoice according to its 32
// byte payment hash.
LookupInvoice(chainhash.Hash) (channeldb.Invoice, error)
// byte payment hash. This method should also reutrn the min final CLTV
// delta for this invoice. We'll use this to ensure that the HTLC
// extended to us gives us enough time to settle as we prescribe.
LookupInvoice(chainhash.Hash) (channeldb.Invoice, uint32, error)
// SettleInvoice attempts to mark an invoice corresponding to the
// passed payment hash as fully settled.

@ -2167,7 +2167,9 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// we attempt to see if we have an invoice locally
// which'll allow us to settle this htlc.
invoiceHash := chainhash.Hash(pd.RHash)
invoice, err := l.cfg.Registry.LookupInvoice(invoiceHash)
invoice, minCltvDelta, err := l.cfg.Registry.LookupInvoice(
invoiceHash,
)
if err != nil {
log.Errorf("unable to query invoice registry: "+
" %v", err)
@ -2258,7 +2260,6 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// We'll also ensure that our time-lock value has been
// computed correctly.
minCltvDelta := uint32(invoice.Terms.FinalCltvDelta)
expectedHeight := heightNow + minCltvDelta
switch {

@ -227,7 +227,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
// Check that alice invoice was settled and bandwidth of HTLC
// links was changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -452,7 +452,7 @@ func TestChannelLinkMultiHopPayment(t *testing.T) {
// Check that Carol invoice was settled and bandwidth of HTLC
// links were changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -793,7 +793,7 @@ func TestUpdateForwardingPolicy(t *testing.T) {
// Carol's invoice should now be shown as settled as the payment
// succeeded.
invoice, err := n.carolServer.registry.LookupInvoice(payResp)
invoice, _, err := n.carolServer.registry.LookupInvoice(payResp)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -909,7 +909,7 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) {
// Check that alice invoice wasn't settled and bandwidth of htlc
// links hasn't been changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -1056,7 +1056,9 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) {
htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight,
n.firstBobChannelLink, n.carolChannelLink)
daveServer, err := newMockServer(t, "dave", testStartingHeight, nil)
daveServer, err := newMockServer(
t, "dave", testStartingHeight, nil, n.globalPolicy.TimeLockDelta,
)
if err != nil {
t.Fatalf("unable to init dave's server: %v", err)
}
@ -1077,7 +1079,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) {
// Check that alice invoice wasn't settled and bandwidth of htlc
// links hasn't been changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -1165,7 +1167,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) {
// Check that alice invoice wasn't settled and bandwidth of htlc
// links hasn't been changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -1442,10 +1444,9 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
}
var (
invoiceRegistry = newMockRegistry()
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
sentMsgs: make(chan lnwire.Message, 2000),
quit: make(chan struct{}),
}
@ -1454,6 +1455,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
BaseFee: lnwire.NewMSatFromSatoshis(1),
TimeLockDelta: 6,
}
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
)
pCache := &mockPreimageCache{
@ -3257,7 +3259,7 @@ func TestChannelRetransmission(t *testing.T) {
// Check that alice invoice wasn't settled and
// bandwidth of htlc links hasn't been changed.
invoice, err = receiver.registry.LookupInvoice(rhash)
invoice, _, err = receiver.registry.LookupInvoice(rhash)
if err != nil {
err = errors.Errorf("unable to get invoice: %v", err)
continue
@ -3624,7 +3626,7 @@ func TestChannelLinkAcceptOverpay(t *testing.T) {
// Even though we sent 2x what was asked for, Carol should still have
// accepted the payment and marked it as settled.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -3819,10 +3821,9 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) {
var (
invoiceRegistry = newMockRegistry()
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
sentMsgs: make(chan lnwire.Message, 2000),
quit: make(chan struct{}),
}
@ -3833,6 +3834,8 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
TimeLockDelta: 6,
}
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
pCache = &mockPreimageCache{
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),

@ -149,7 +149,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error)
}
func newMockServer(t testing.TB, name string, startingHeight uint32,
db *channeldb.DB) (*mockServer, error) {
db *channeldb.DB, defaultDelta uint32) (*mockServer, error) {
var id [33]byte
h := sha256.Sum256([]byte(name))
@ -166,7 +166,7 @@ func newMockServer(t testing.TB, name string, startingHeight uint32,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: newMockRegistry(),
registry: newMockRegistry(defaultDelta),
htlcSwitch: htlcSwitch,
interceptorFuncs: make([]messageInterceptor, 0),
}, nil
@ -648,25 +648,29 @@ var _ ChannelLink = (*mockChannelLink)(nil)
type mockInvoiceRegistry struct {
sync.Mutex
invoices map[chainhash.Hash]channeldb.Invoice
invoices map[chainhash.Hash]channeldb.Invoice
finalDelta uint32
}
func newMockRegistry() *mockInvoiceRegistry {
func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
return &mockInvoiceRegistry{
invoices: make(map[chainhash.Hash]channeldb.Invoice),
finalDelta: minDelta,
invoices: make(map[chainhash.Hash]channeldb.Invoice),
}
}
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, uint32, error) {
i.Lock()
defer i.Unlock()
invoice, ok := i.invoices[rHash]
if !ok {
return channeldb.Invoice{}, fmt.Errorf("can't find mock invoice: %x", rHash[:])
return channeldb.Invoice{}, 0, fmt.Errorf("can't find mock "+
"invoice: %x", rHash[:])
}
return invoice, nil
return invoice, i.finalDelta, nil
}
func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {

@ -30,7 +30,7 @@ func genPreimage() ([32]byte, error) {
func TestSwitchSendPending(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -125,11 +125,11 @@ func TestSwitchSendPending(t *testing.T) {
func TestSwitchForward(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -230,11 +230,11 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -421,11 +421,11 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -615,11 +615,11 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -778,11 +778,11 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -936,11 +936,11 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1167,11 +1167,11 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
var packet *htlcPacket
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1237,7 +1237,7 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
// We'll create a single link for this test, marking it as being unable
// to forward form the get go.
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -1289,11 +1289,11 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
func TestSwitchCancel(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1402,11 +1402,11 @@ func TestSwitchAddSamePayment(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1561,7 +1561,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
func TestSwitchSendPayment(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}

@ -844,16 +844,24 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
bobDb := firstBobChannel.State().Db
carolDb := carolChannel.State().Db
defaultDelta := uint32(6)
// Create three peers/servers.
aliceServer, err := newMockServer(t, "alice", startingHeight, aliceDb)
aliceServer, err := newMockServer(
t, "alice", startingHeight, aliceDb, defaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobServer, err := newMockServer(t, "bob", startingHeight, bobDb)
bobServer, err := newMockServer(
t, "bob", startingHeight, bobDb, defaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
carolServer, err := newMockServer(t, "carol", startingHeight, carolDb)
carolServer, err := newMockServer(
t, "carol", startingHeight, carolDb, defaultDelta,
)
if err != nil {
t.Fatalf("unable to create carol server: %v", err)
}
@ -883,7 +891,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
globalPolicy := ForwardingPolicy{
MinHTLC: lnwire.NewMSatFromSatoshis(5),
BaseFee: lnwire.NewMSatFromSatoshis(1),
TimeLockDelta: 6,
TimeLockDelta: defaultDelta,
}
obfuscator := NewMockObfuscator()