Merge pull request #1477 from Roasbeef/proper-exit-hop-cltv-fix

channeldb+htlcswitch: parse out min cltv delta from payreq rather than modifying the db
This commit is contained in:
Olaoluwa Osuntokun 2018-06-29 18:11:19 -07:00 committed by GitHub
commit 9205720bea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 96 additions and 88 deletions

@ -3,7 +3,6 @@ package channeldb
import (
"crypto/rand"
"crypto/sha256"
prand "math/rand"
"reflect"
"testing"
"time"
@ -25,7 +24,6 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
Terms: ContractTerm{
PaymentPreimage: pre,
Value: value,
FinalCltvDelta: uint16(prand.Int31()),
},
}
i.Memo = []byte("memo")
@ -68,7 +66,6 @@ func TestInvoiceWorkflow(t *testing.T) {
fakeInvoice.PaymentRequest = []byte("")
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
fakeInvoice.Terms.FinalCltvDelta = uint16(prand.Int31())
// Add the invoice to the database, this should succeed as there aren't
// any existing invoices within the database with the same payment

@ -68,13 +68,6 @@ type ContractTerm struct {
// Settled indicates if this particular contract term has been fully
// settled by the payer.
Settled bool
// FinalCltvDelta is the lower bound of a delta from the current height
// that the HTLC that wishes to settle this invoice MUST carry. This
// allows the receiver to specify the time window that should be
// available for them to sweep the HTLC on-chain if that becomes
// necessary.
FinalCltvDelta uint16
}
// Invoice is a payment invoice generated by a payee in order to request
@ -86,7 +79,8 @@ type ContractTerm struct {
// invoices are never deleted from the database, instead a bit is toggled
// denoting the invoice has been fully settled. Within the database, all
// invoices must have a unique payment hash which is generated by taking the
// sha256 of the payment preimage.
// sha256 of the payment
// preimage.
type Invoice struct {
// Memo is an optional memo to be stored along side an invoice. The
// memo may contain further details pertaining to the invoice itself,
@ -367,7 +361,7 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
return err
}
return binary.Write(w, byteOrder, i.Terms.FinalCltvDelta)
return nil
}
func fetchInvoice(invoiceNum []byte, invoices *bolt.Bucket) (*Invoice, error) {
@ -429,18 +423,6 @@ func deserializeInvoice(r io.Reader) (*Invoice, error) {
return nil, err
}
// Before we return with the current invoice, we'll check to see if
// there's still enough space in the buffer to read out the final ctlv
// delta. We'll get an EOF error if there isn't any thing else
// lingering in the buffer.
err = binary.Read(r, byteOrder, &invoice.Terms.FinalCltvDelta)
if err != nil && err != io.EOF {
// If we got a non-eof error, then we know there's an actually
// issue. Otherwise, it may have been the case that this
// summary didn't have the set of optional fields.
return nil, err
}
return invoice, nil
}

@ -11,8 +11,10 @@ import (
// which may search, lookup and settle invoices.
type InvoiceDatabase interface {
// LookupInvoice attempts to look up an invoice according to its 32
// byte payment hash.
LookupInvoice(chainhash.Hash) (channeldb.Invoice, error)
// byte payment hash. This method should also reutrn the min final CLTV
// delta for this invoice. We'll use this to ensure that the HTLC
// extended to us gives us enough time to settle as we prescribe.
LookupInvoice(chainhash.Hash) (channeldb.Invoice, uint32, error)
// SettleInvoice attempts to mark an invoice corresponding to the
// passed payment hash as fully settled.

@ -2167,7 +2167,9 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// we attempt to see if we have an invoice locally
// which'll allow us to settle this htlc.
invoiceHash := chainhash.Hash(pd.RHash)
invoice, err := l.cfg.Registry.LookupInvoice(invoiceHash)
invoice, minCltvDelta, err := l.cfg.Registry.LookupInvoice(
invoiceHash,
)
if err != nil {
log.Errorf("unable to query invoice registry: "+
" %v", err)
@ -2258,7 +2260,6 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// We'll also ensure that our time-lock value has been
// computed correctly.
minCltvDelta := uint32(invoice.Terms.FinalCltvDelta)
expectedHeight := heightNow + minCltvDelta
switch {

@ -227,7 +227,7 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
// Check that alice invoice was settled and bandwidth of HTLC
// links was changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -452,7 +452,7 @@ func TestChannelLinkMultiHopPayment(t *testing.T) {
// Check that Carol invoice was settled and bandwidth of HTLC
// links were changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -793,7 +793,7 @@ func TestUpdateForwardingPolicy(t *testing.T) {
// Carol's invoice should now be shown as settled as the payment
// succeeded.
invoice, err := n.carolServer.registry.LookupInvoice(payResp)
invoice, _, err := n.carolServer.registry.LookupInvoice(payResp)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -909,7 +909,7 @@ func TestChannelLinkMultiHopInsufficientPayment(t *testing.T) {
// Check that alice invoice wasn't settled and bandwidth of htlc
// links hasn't been changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -1056,7 +1056,9 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) {
htlcAmt, totalTimelock, hops := generateHops(amount, testStartingHeight,
n.firstBobChannelLink, n.carolChannelLink)
daveServer, err := newMockServer(t, "dave", testStartingHeight, nil)
daveServer, err := newMockServer(
t, "dave", testStartingHeight, nil, n.globalPolicy.TimeLockDelta,
)
if err != nil {
t.Fatalf("unable to init dave's server: %v", err)
}
@ -1077,7 +1079,7 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) {
// Check that alice invoice wasn't settled and bandwidth of htlc
// links hasn't been changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -1165,7 +1167,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) {
// Check that alice invoice wasn't settled and bandwidth of htlc
// links hasn't been changed.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -1442,7 +1444,6 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
}
var (
invoiceRegistry = newMockRegistry()
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
@ -1454,6 +1455,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
BaseFee: lnwire.NewMSatFromSatoshis(1),
TimeLockDelta: 6,
}
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
)
pCache := &mockPreimageCache{
@ -3257,7 +3259,7 @@ func TestChannelRetransmission(t *testing.T) {
// Check that alice invoice wasn't settled and
// bandwidth of htlc links hasn't been changed.
invoice, err = receiver.registry.LookupInvoice(rhash)
invoice, _, err = receiver.registry.LookupInvoice(rhash)
if err != nil {
err = errors.Errorf("unable to get invoice: %v", err)
continue
@ -3624,7 +3626,7 @@ func TestChannelLinkAcceptOverpay(t *testing.T) {
// Even though we sent 2x what was asked for, Carol should still have
// accepted the payment and marked it as settled.
invoice, err := receiver.registry.LookupInvoice(rhash)
invoice, _, err := receiver.registry.LookupInvoice(rhash)
if err != nil {
t.Fatalf("unable to get invoice: %v", err)
}
@ -3819,7 +3821,6 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
hodlFlags []hodl.Flag) (ChannelLink, chan time.Time, func(), error) {
var (
invoiceRegistry = newMockRegistry()
decoder = newMockIteratorDecoder()
obfuscator = NewMockObfuscator()
alicePeer = &mockPeer{
@ -3833,6 +3834,8 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
TimeLockDelta: 6,
}
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
pCache = &mockPreimageCache{
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),

@ -149,7 +149,7 @@ func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error)
}
func newMockServer(t testing.TB, name string, startingHeight uint32,
db *channeldb.DB) (*mockServer, error) {
db *channeldb.DB, defaultDelta uint32) (*mockServer, error) {
var id [33]byte
h := sha256.Sum256([]byte(name))
@ -166,7 +166,7 @@ func newMockServer(t testing.TB, name string, startingHeight uint32,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: newMockRegistry(),
registry: newMockRegistry(defaultDelta),
htlcSwitch: htlcSwitch,
interceptorFuncs: make([]messageInterceptor, 0),
}, nil
@ -648,25 +648,29 @@ var _ ChannelLink = (*mockChannelLink)(nil)
type mockInvoiceRegistry struct {
sync.Mutex
invoices map[chainhash.Hash]channeldb.Invoice
finalDelta uint32
}
func newMockRegistry() *mockInvoiceRegistry {
func newMockRegistry(minDelta uint32) *mockInvoiceRegistry {
return &mockInvoiceRegistry{
finalDelta: minDelta,
invoices: make(map[chainhash.Hash]channeldb.Invoice),
}
}
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, uint32, error) {
i.Lock()
defer i.Unlock()
invoice, ok := i.invoices[rHash]
if !ok {
return channeldb.Invoice{}, fmt.Errorf("can't find mock invoice: %x", rHash[:])
return channeldb.Invoice{}, 0, fmt.Errorf("can't find mock "+
"invoice: %x", rHash[:])
}
return invoice, nil
return invoice, i.finalDelta, nil
}
func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {

@ -30,7 +30,7 @@ func genPreimage() ([32]byte, error) {
func TestSwitchSendPending(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -125,11 +125,11 @@ func TestSwitchSendPending(t *testing.T) {
func TestSwitchForward(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -230,11 +230,11 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -421,11 +421,11 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -615,11 +615,11 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -778,11 +778,11 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -936,11 +936,11 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1167,11 +1167,11 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) {
var packet *htlcPacket
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1237,7 +1237,7 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
// We'll create a single link for this test, marking it as being unable
// to forward form the get go.
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -1289,11 +1289,11 @@ func TestSkipIneligibleLinksLocalForward(t *testing.T) {
func TestSwitchCancel(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1402,11 +1402,11 @@ func TestSwitchAddSamePayment(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil)
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1561,7 +1561,7 @@ func TestSwitchAddSamePayment(t *testing.T) {
func TestSwitchSendPayment(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil)
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}

@ -844,16 +844,24 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
bobDb := firstBobChannel.State().Db
carolDb := carolChannel.State().Db
defaultDelta := uint32(6)
// Create three peers/servers.
aliceServer, err := newMockServer(t, "alice", startingHeight, aliceDb)
aliceServer, err := newMockServer(
t, "alice", startingHeight, aliceDb, defaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobServer, err := newMockServer(t, "bob", startingHeight, bobDb)
bobServer, err := newMockServer(
t, "bob", startingHeight, bobDb, defaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
carolServer, err := newMockServer(t, "carol", startingHeight, carolDb)
carolServer, err := newMockServer(
t, "carol", startingHeight, carolDb, defaultDelta,
)
if err != nil {
t.Fatalf("unable to create carol server: %v", err)
}
@ -883,7 +891,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
globalPolicy := ForwardingPolicy{
MinHTLC: lnwire.NewMSatFromSatoshis(5),
BaseFee: lnwire.NewMSatFromSatoshis(1),
TimeLockDelta: 6,
TimeLockDelta: defaultDelta,
}
obfuscator := NewMockObfuscator()

@ -9,6 +9,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/zpay32"
"github.com/roasbeef/btcd/chaincfg/chainhash"
"github.com/roasbeef/btcutil"
)
@ -95,10 +96,14 @@ func (i *invoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error {
//go i.notifyClients(invoice, false)
}
// lookupInvoice looks up an invoice by its payment hash (R-Hash), if found
// then we're able to pull the funds pending within an HTLC.
// LookupInvoice looks up an invoice by its payment hash (R-Hash), if found
// then we're able to pull the funds pending within an HTLC. We'll also return
// what the expected min final CLTV delta is, pre-parsed from the payment
// request. This may be used by callers to determine if an HTLC is well formed
// according to the cltv delta.
//
// TODO(roasbeef): ignore if settled?
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, uint32, error) {
// First check the in-memory debug invoice index to see if this is an
// existing invoice added for debugging.
i.RLock()
@ -107,17 +112,24 @@ func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice
// If found, then simply return the invoice directly.
if ok {
return *invoice, nil
return *invoice, 0, nil
}
// Otherwise, we'll check the database to see if there's an existing
// matching invoice.
invoice, err := i.cdb.LookupInvoice(rHash)
if err != nil {
return channeldb.Invoice{}, err
return channeldb.Invoice{}, 0, err
}
return *invoice, nil
payReq, err := zpay32.Decode(
string(invoice.PaymentRequest), activeNetParams.Params,
)
if err != nil {
return channeldb.Invoice{}, 0, err
}
return *invoice, uint32(payReq.MinFinalCLTVExpiry()), nil
}
// SettleInvoice attempts to mark an invoice as settled. If the invoice is a

@ -2661,7 +2661,6 @@ func (r *rpcServer) AddInvoice(ctx context.Context,
PaymentRequest: []byte(payReqString),
Terms: channeldb.ContractTerm{
Value: amtMSat,
FinalCltvDelta: uint16(payReq.MinFinalCLTVExpiry()),
},
}
copy(i.Terms.PaymentPreimage[:], paymentPreimage[:])
@ -2800,7 +2799,7 @@ func (r *rpcServer) LookupInvoice(ctx context.Context,
rpcsLog.Tracef("[lookupinvoice] searching for invoice %x", payHash[:])
invoice, err := r.server.invoices.LookupInvoice(payHash)
invoice, _, err := r.server.invoices.LookupInvoice(payHash)
if err != nil {
return nil, err
}

@ -73,7 +73,7 @@ func (p *preimageBeacon) LookupPreimage(payHash []byte) ([]byte, bool) {
// the preimage as it's on that we created ourselves.
var invoiceKey chainhash.Hash
copy(invoiceKey[:], payHash)
invoice, err := p.invoices.LookupInvoice(invoiceKey)
invoice, _, err := p.invoices.LookupInvoice(invoiceKey)
switch {
case err == channeldb.ErrInvoiceNotFound:
// If we get this error, then it simply means that this invoice