diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index e5cf4cd8..196dad4c 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -11,8 +11,10 @@ import ( // which may search, lookup and settle invoices. type InvoiceDatabase interface { // LookupInvoice attempts to look up an invoice according to its 32 - // byte payment hash. - LookupInvoice(chainhash.Hash) (channeldb.Invoice, error) + // byte payment hash. This method should also reutrn the min final CLTV + // delta for this invoice. We'll use this to ensure that the HTLC + // extended to us gives us enough time to settle as we prescribe. + LookupInvoice(chainhash.Hash) (channeldb.Invoice, uint32, error) // SettleInvoice attempts to mark an invoice corresponding to the // passed payment hash as fully settled. diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 0336ae34..b60b43f1 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2167,7 +2167,9 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // we attempt to see if we have an invoice locally // which'll allow us to settle this htlc. invoiceHash := chainhash.Hash(pd.RHash) - invoice, err := l.cfg.Registry.LookupInvoice(invoiceHash) + invoice, minCltvDelta, err := l.cfg.Registry.LookupInvoice( + invoiceHash, + ) if err != nil { log.Errorf("unable to query invoice registry: "+ " %v", err) @@ -2258,7 +2260,6 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // We'll also ensure that our time-lock value has been // computed correctly. - minCltvDelta := uint32(invoice.Terms.FinalCltvDelta) expectedHeight := heightNow + minCltvDelta switch { diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 2cfba17b..197d1d45 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -227,7 +227,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) { // Check that alice invoice was settled and bandwidth of HTLC // links was changed. - invoice, err := receiver.registry.LookupInvoice(rhash) + invoice, _, err := receiver.registry.LookupInvoice(rhash) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -452,7 +452,7 @@ func TestChannelLinkMultiHopPayment(t *testing.T) { // Check that Carol invoice was settled and bandwidth of HTLC // links were changed. - invoice, err := receiver.registry.LookupInvoice(rhash) + invoice, _, err := receiver.registry.LookupInvoice(rhash) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -793,7 +793,7 @@ func TestUpdateForwardingPolicy(t *testing.T) { // Carol's invoice should now be shown as settled as the payment // succeeded. - invoice, err := n.carolServer.registry.LookupInvoice(payResp) + invoice, _, err := n.carolServer.registry.LookupInvoice(payResp) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -909,7 +909,7 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. - invoice, err := receiver.registry.LookupInvoice(rhash) + invoice, _, err := receiver.registry.LookupInvoice(rhash) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -1056,7 +1056,9 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight, n.firstBobChannelLink, n.carolChannelLink) - daveServer, err := newMockServer(t, "dave", testStartingHeight, nil) + daveServer, err := newMockServer( + t, "dave", testStartingHeight, nil, n.globalPolicy.TimeLockDelta, + ) if err != nil { t.Fatalf("unable to init dave's server: %v", err) } @@ -1077,7 +1079,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. - invoice, err := receiver.registry.LookupInvoice(rhash) + invoice, _, err := receiver.registry.LookupInvoice(rhash) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -1165,7 +1167,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { // Check that alice invoice wasn't settled and bandwidth of htlc // links hasn't been changed. - invoice, err := receiver.registry.LookupInvoice(rhash) + invoice, _, err := receiver.registry.LookupInvoice(rhash) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -1442,10 +1444,9 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( } var ( - invoiceRegistry = newMockRegistry() - decoder = newMockIteratorDecoder() - obfuscator = NewMockObfuscator() - alicePeer = &mockPeer{ + decoder = newMockIteratorDecoder() + obfuscator = NewMockObfuscator() + alicePeer = &mockPeer{ sentMsgs: make(chan lnwire.Message, 2000), quit: make(chan struct{}), } @@ -1454,6 +1455,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( BaseFee: lnwire.NewMSatFromSatoshis(1), TimeLockDelta: 6, } + invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta) ) pCache := &mockPreimageCache{ @@ -3257,7 +3259,7 @@ func TestChannelRetransmission(t *testing.T) { // Check that alice invoice wasn't settled and // bandwidth of htlc links hasn't been changed. - invoice, err = receiver.registry.LookupInvoice(rhash) + invoice, _, err = receiver.registry.LookupInvoice(rhash) if err != nil { err = errors.Errorf("unable to get invoice: %v", err) continue @@ -3624,7 +3626,7 @@ func TestChannelLinkAcceptOverpay(t *testing.T) { // Even though we sent 2x what was asked for, Carol should still have // accepted the payment and marked it as settled. - invoice, err := receiver.registry.LookupInvoice(rhash) + invoice, _, err := receiver.registry.LookupInvoice(rhash) if err != nil { t.Fatalf("unable to get invoice: %v", err) } @@ -3819,10 +3821,9 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) { var ( - invoiceRegistry = newMockRegistry() - decoder = newMockIteratorDecoder() - obfuscator = NewMockObfuscator() - alicePeer = &mockPeer{ + decoder = newMockIteratorDecoder() + obfuscator = NewMockObfuscator() + alicePeer = &mockPeer{ sentMsgs: make(chan lnwire.Message, 2000), quit: make(chan struct{}), } @@ -3833,6 +3834,8 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch, TimeLockDelta: 6, } + invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta) + pCache = &mockPreimageCache{ // hash -> preimage preimageMap: make(map[[32]byte][]byte), diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index b96fa1e4..e0016188 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -149,7 +149,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) } func newMockServer(t testing.TB, name string, startingHeight uint32, - db *channeldb.DB) (*mockServer, error) { + db *channeldb.DB, defaultDelta uint32) (*mockServer, error) { var id [33]byte h := sha256.Sum256([]byte(name)) @@ -166,7 +166,7 @@ func newMockServer(t testing.TB, name string, startingHeight uint32, name: name, messages: make(chan lnwire.Message, 3000), quit: make(chan struct{}), - registry: newMockRegistry(), + registry: newMockRegistry(defaultDelta), htlcSwitch: htlcSwitch, interceptorFuncs: make([]messageInterceptor, 0), }, nil @@ -648,25 +648,29 @@ var _ ChannelLink = (*mockChannelLink)(nil) type mockInvoiceRegistry struct { sync.Mutex - invoices map[chainhash.Hash]channeldb.Invoice + + invoices map[chainhash.Hash]channeldb.Invoice + finalDelta uint32 } -func newMockRegistry() *mockInvoiceRegistry { +func newMockRegistry(minDelta uint32) *mockInvoiceRegistry { return &mockInvoiceRegistry{ - invoices: make(map[chainhash.Hash]channeldb.Invoice), + finalDelta: minDelta, + invoices: make(map[chainhash.Hash]channeldb.Invoice), } } -func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) { +func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, uint32, error) { i.Lock() defer i.Unlock() invoice, ok := i.invoices[rHash] if !ok { - return channeldb.Invoice{}, fmt.Errorf("can't find mock invoice: %x", rHash[:]) + return channeldb.Invoice{}, 0, fmt.Errorf("can't find mock "+ + "invoice: %x", rHash[:]) } - return invoice, nil + return invoice, i.finalDelta, nil } func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error { diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index ae0a2586..955d89ba 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -30,7 +30,7 @@ func genPreimage() ([32]byte, error) { func TestSwitchSendPending(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -125,11 +125,11 @@ func TestSwitchSendPending(t *testing.T) { func TestSwitchForward(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -230,11 +230,11 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -421,11 +421,11 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -615,11 +615,11 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -778,11 +778,11 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -936,11 +936,11 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1167,11 +1167,11 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { var packet *htlcPacket - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1237,7 +1237,7 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { // We'll create a single link for this test, marking it as being unable // to forward form the get go. - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } @@ -1289,11 +1289,11 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) { func TestSwitchCancel(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1402,11 +1402,11 @@ func TestSwitchAddSamePayment(t *testing.T) { chanID1, chanID2, aliceChanID, bobChanID := genIDs() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil) + bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create bob server: %v", err) } @@ -1561,7 +1561,7 @@ func TestSwitchAddSamePayment(t *testing.T) { func TestSwitchSendPayment(t *testing.T) { t.Parallel() - alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil) + alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6) if err != nil { t.Fatalf("unable to create alice server: %v", err) } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 8da96ece..56af3f90 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -844,16 +844,24 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, bobDb := firstBobChannel.State().Db carolDb := carolChannel.State().Db + defaultDelta := uint32(6) + // Create three peers/servers. - aliceServer, err := newMockServer(t, "alice", startingHeight, aliceDb) + aliceServer, err := newMockServer( + t, "alice", startingHeight, aliceDb, defaultDelta, + ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } - bobServer, err := newMockServer(t, "bob", startingHeight, bobDb) + bobServer, err := newMockServer( + t, "bob", startingHeight, bobDb, defaultDelta, + ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } - carolServer, err := newMockServer(t, "carol", startingHeight, carolDb) + carolServer, err := newMockServer( + t, "carol", startingHeight, carolDb, defaultDelta, + ) if err != nil { t.Fatalf("unable to create carol server: %v", err) } @@ -883,7 +891,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, globalPolicy := ForwardingPolicy{ MinHTLC: lnwire.NewMSatFromSatoshis(5), BaseFee: lnwire.NewMSatFromSatoshis(1), - TimeLockDelta: 6, + TimeLockDelta: defaultDelta, } obfuscator := NewMockObfuscator()