channeldb+invoices: use payment addr as primary index

This commit is contained in:
Conner Fromknecht 2020-05-21 15:37:39 -07:00
parent 3522f09a08
commit cbf71b5452
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
7 changed files with 286 additions and 34 deletions

@ -13,6 +13,7 @@ import (
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/channeldb/kvdb"
mig "github.com/lightningnetwork/lnd/channeldb/migration"
"github.com/lightningnetwork/lnd/channeldb/migration12" "github.com/lightningnetwork/lnd/channeldb/migration12"
"github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration13"
"github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11"
@ -136,6 +137,13 @@ var (
number: 13, number: 13,
migration: migration13.MigrateMPP, migration: migration13.MigrateMPP,
}, },
{
// Initialize payment address index and begin using it
// as the default index, falling back to payment hash
// index.
number: 14,
migration: mig.CreateTLB(payAddrIndexBucket),
},
} }
// Big endian is the preferred byte order, due to cursor scans over // Big endian is the preferred byte order, due to cursor scans over
@ -248,6 +256,7 @@ var topLevelBuckets = [][]byte{
forwardingLogBucket, forwardingLogBucket,
fwdPackagesKey, fwdPackagesKey,
invoiceBucket, invoiceBucket,
payAddrIndexBucket,
nodeInfoBucket, nodeInfoBucket,
nodeBucket, nodeBucket,
edgeBucket, edgeBucket,

@ -43,6 +43,14 @@ var (
// payment hash already exists. // payment hash already exists.
ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists")
// ErrDuplicatePayAddr is returned when an invoice with the target
// payment addr already exists.
ErrDuplicatePayAddr = fmt.Errorf("invoice with payemnt addr already exists")
// ErrInvRefEquivocation is returned when an InvoiceRef targets
// multiple, distinct invoices.
ErrInvRefEquivocation = errors.New("inv ref matches multiple invoices")
// ErrNoPaymentsCreated is returned when bucket of payments hasn't been // ErrNoPaymentsCreated is returned when bucket of payments hasn't been
// created. // created.
ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments")

@ -20,16 +20,20 @@ var (
) )
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
var pre [32]byte var pre, payAddr [32]byte
if _, err := rand.Read(pre[:]); err != nil { if _, err := rand.Read(pre[:]); err != nil {
return nil, err return nil, err
} }
if _, err := rand.Read(payAddr[:]); err != nil {
return nil, err
}
i := &Invoice{ i := &Invoice{
CreationDate: testNow, CreationDate: testNow,
Terms: ContractTerm{ Terms: ContractTerm{
Expiry: 4000, Expiry: 4000,
PaymentPreimage: pre, PaymentPreimage: pre,
PaymentAddr: payAddr,
Value: value, Value: value,
Features: emptyFeatures, Features: emptyFeatures,
}, },
@ -91,9 +95,45 @@ func TestInvoiceIsPending(t *testing.T) {
} }
} }
type invWorkflowTest struct {
name string
queryPayHash bool
queryPayAddr bool
}
var invWorkflowTests = []invWorkflowTest{
{
name: "unknown",
queryPayHash: false,
queryPayAddr: false,
},
{
name: "only payhash known",
queryPayHash: true,
queryPayAddr: false,
},
{
name: "payaddr and payhash known",
queryPayHash: true,
queryPayAddr: true,
},
}
// TestInvoiceWorkflow asserts the basic process of inserting, fetching, and
// updating an invoice. We assert that the flow is successful using when
// querying with various combinations of payment hash and payment address.
func TestInvoiceWorkflow(t *testing.T) { func TestInvoiceWorkflow(t *testing.T) {
t.Parallel() t.Parallel()
for _, test := range invWorkflowTests {
test := test
t.Run(test.name, func(t *testing.T) {
testInvoiceWorkflow(t, test)
})
}
}
func testInvoiceWorkflow(t *testing.T, test invWorkflowTest) {
db, cleanUp, err := makeTestDB() db, cleanUp, err := makeTestDB()
defer cleanUp() defer cleanUp()
if err != nil { if err != nil {
@ -102,23 +142,33 @@ func TestInvoiceWorkflow(t *testing.T) {
// Create a fake invoice which we'll use several times in the tests // Create a fake invoice which we'll use several times in the tests
// below. // below.
fakeInvoice := &Invoice{ fakeInvoice, err := randInvoice(10000)
CreationDate: testNow, if err != nil {
Htlcs: map[CircuitKey]*InvoiceHTLC{}, t.Fatalf("unable to create invoice: %v", err)
} }
fakeInvoice.Memo = []byte("memo") invPayHash := fakeInvoice.Terms.PaymentPreimage.Hash()
fakeInvoice.PaymentRequest = []byte("")
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
fakeInvoice.Terms.Features = emptyFeatures
paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() // Select the payment hash and payment address we will use to lookup or
ref := InvoiceRefByHash(paymentHash) // update the invoice for the remainder of the test.
var (
payHash lntypes.Hash
payAddr *[32]byte
ref InvoiceRef
)
switch {
case test.queryPayHash && test.queryPayAddr:
payHash = invPayHash
payAddr = &fakeInvoice.Terms.PaymentAddr
ref = InvoiceRefByHashAndAddr(payHash, *payAddr)
case test.queryPayHash:
payHash = invPayHash
ref = InvoiceRefByHash(payHash)
}
// Add the invoice to the database, this should succeed as there aren't // Add the invoice to the database, this should succeed as there aren't
// any existing invoices within the database with the same payment // any existing invoices within the database with the same payment
// hash. // hash.
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { if _, err := db.AddInvoice(fakeInvoice, invPayHash); err != nil {
t.Fatalf("unable to find invoice: %v", err) t.Fatalf("unable to find invoice: %v", err)
} }
@ -126,8 +176,11 @@ func TestInvoiceWorkflow(t *testing.T) {
// database. It should be found, and the invoice returned should be // database. It should be found, and the invoice returned should be
// identical to the one created above. // identical to the one created above.
dbInvoice, err := db.LookupInvoice(ref) dbInvoice, err := db.LookupInvoice(ref)
if err != nil { if !test.queryPayAddr && !test.queryPayHash {
t.Fatalf("unable to find invoice: %v", err) if err != ErrInvoiceNotFound {
t.Fatalf("invoice should not exist: %v", err)
}
return
} }
if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { if !reflect.DeepEqual(*fakeInvoice, dbInvoice) {
t.Fatalf("invoice fetched from db doesn't match original %v vs %v", t.Fatalf("invoice fetched from db doesn't match original %v vs %v",
@ -174,7 +227,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// Attempt to insert generated above again, this should fail as // Attempt to insert generated above again, this should fail as
// duplicates are rejected by the processing logic. // duplicates are rejected by the processing logic.
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { if _, err := db.AddInvoice(fakeInvoice, payHash); err != ErrDuplicateInvoice {
t.Fatalf("invoice insertion should fail due to duplication, "+ t.Fatalf("invoice insertion should fail due to duplication, "+
"instead %v", err) "instead %v", err)
} }
@ -232,6 +285,70 @@ func TestInvoiceWorkflow(t *testing.T) {
} }
} }
// TestAddDuplicatePayAddr asserts that the payment addresses of inserted
// invoices are unique.
func TestAddDuplicatePayAddr(t *testing.T) {
db, cleanUp, err := makeTestDB()
defer cleanUp()
assert.Nil(t, err)
// Create two invoices with the same payment addr.
invoice1, err := randInvoice(1000)
assert.Nil(t, err)
invoice2, err := randInvoice(20000)
assert.Nil(t, err)
invoice2.Terms.PaymentAddr = invoice1.Terms.PaymentAddr
// First insert should succeed.
inv1Hash := invoice1.Terms.PaymentPreimage.Hash()
_, err = db.AddInvoice(invoice1, inv1Hash)
assert.Nil(t, err)
// Second insert should fail with duplicate payment addr.
inv2Hash := invoice2.Terms.PaymentPreimage.Hash()
_, err = db.AddInvoice(invoice2, inv2Hash)
assert.Equal(t, ErrDuplicatePayAddr, err)
}
// TestInvRefEquivocation asserts that retrieving or updating an invoice using
// an equivocating InvoiceRef results in ErrInvRefEquivocation.
func TestInvRefEquivocation(t *testing.T) {
db, cleanUp, err := makeTestDB()
defer cleanUp()
assert.Nil(t, err)
// Add two random invoices.
invoice1, err := randInvoice(1000)
assert.Nil(t, err)
inv1Hash := invoice1.Terms.PaymentPreimage.Hash()
_, err = db.AddInvoice(invoice1, inv1Hash)
assert.Nil(t, err)
invoice2, err := randInvoice(2000)
assert.Nil(t, err)
inv2Hash := invoice2.Terms.PaymentPreimage.Hash()
_, err = db.AddInvoice(invoice2, inv2Hash)
assert.Nil(t, err)
// Now, query using invoice 1's payment address, but invoice 2's payment
// hash. We expect an error since the invref points to multiple
// invoices.
ref := InvoiceRefByHashAndAddr(inv2Hash, invoice1.Terms.PaymentAddr)
_, err = db.LookupInvoice(ref)
assert.Equal(t, ErrInvRefEquivocation, err)
// The same error should be returned when updating an equivocating
// reference.
nop := func(_ *Invoice) (*InvoiceUpdateDesc, error) {
return nil, nil
}
_, err = db.UpdateInvoice(ref, nop)
assert.Equal(t, ErrInvRefEquivocation, err)
}
// TestInvoiceCancelSingleHtlc tests that a single htlc can be canceled on the // TestInvoiceCancelSingleHtlc tests that a single htlc can be canceled on the
// invoice. // invoice.
func TestInvoiceCancelSingleHtlc(t *testing.T) { func TestInvoiceCancelSingleHtlc(t *testing.T) {
@ -991,9 +1108,17 @@ func TestCustomRecords(t *testing.T) {
// InvoiceRef depending on the constructor used. // InvoiceRef depending on the constructor used.
func TestInvoiceRef(t *testing.T) { func TestInvoiceRef(t *testing.T) {
payHash := lntypes.Hash{0x01} payHash := lntypes.Hash{0x01}
payAddr := [32]byte{0x02}
// An InvoiceRef by hash should return the provided hash and a nil // An InvoiceRef by hash should return the provided hash and a nil
// payment addr. // payment addr.
refByHash := InvoiceRefByHash(payHash) refByHash := InvoiceRefByHash(payHash)
assert.Equal(t, payHash, refByHash.PayHash()) assert.Equal(t, payHash, refByHash.PayHash())
assert.Equal(t, (*[32]byte)(nil), refByHash.PayAddr())
// An InvoiceRef by hash and addr should return the payment hash and
// payment addr passed to the constructor.
refByHashAndAddr := InvoiceRefByHashAndAddr(payHash, payAddr)
assert.Equal(t, payHash, refByHashAndAddr.PayHash())
assert.Equal(t, &payAddr, refByHashAndAddr.PayAddr())
} }

@ -37,6 +37,16 @@ var (
// maps: payHash => invoiceKey // maps: payHash => invoiceKey
invoiceIndexBucket = []byte("paymenthashes") invoiceIndexBucket = []byte("paymenthashes")
// payAddrIndexBucket is the name of the top-level bucket that maps
// payment addresses to their invoice number. This can be used
// to efficiently query or update non-legacy invoices. Note that legacy
// invoices will not be included in this index since they all have the
// same, all-zero payment address, however all newly generated invoices
// will end up in this index.
//
// maps: payAddr => invoiceKey
payAddrIndexBucket = []byte("pay-addr-index")
// numInvoicesKey is the name of key which houses the auto-incrementing // numInvoicesKey is the name of key which houses the auto-incrementing
// invoice ID which is essentially used as a primary key. With each // invoice ID which is essentially used as a primary key. With each
// invoice inserted, the primary key is incremented by one. This key is // invoice inserted, the primary key is incremented by one. This key is
@ -142,12 +152,23 @@ const (
amtPaidType tlv.Type = 13 amtPaidType tlv.Type = 13
) )
// InvoiceRef is an identifier for invoices supporting queries by payment hash. // InvoiceRef is a composite identifier for invoices. Invoices can be referenced
// by various combinations of payment hash and payment addr, in certain contexts
// only some of these are known. An InvoiceRef and its constructors thus
// encapsulate the valid combinations of query parameters that can be supplied
// to LookupInvoice and UpdateInvoice.
type InvoiceRef struct { type InvoiceRef struct {
// payHash is the payment hash of the target invoice. All invoices are // payHash is the payment hash of the target invoice. All invoices are
// currently indexed by payment hash. This value will be used as a // currently indexed by payment hash. This value will be used as a
// fallback when no payment address is known. // fallback when no payment address is known.
payHash lntypes.Hash payHash lntypes.Hash
// payAddr is the payment addr of the target invoice. Newer invoices
// (0.11 and up) are indexed by payment address in addition to payment
// hash, but pre 0.8 invoices do not have one at all. When this value is
// known it will be used as the primary identifier, falling back to
// payHash if no value is known.
payAddr *[32]byte
} }
// InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by // InvoiceRefByHash creates an InvoiceRef that queries for an invoice only by
@ -158,13 +179,39 @@ func InvoiceRefByHash(payHash lntypes.Hash) InvoiceRef {
} }
} }
// InvoiceRefByHashAndAddr creates an InvoiceRef that first queries for an
// invoice by the provided payment address, falling back to the payment hash if
// the payment address is unknown.
func InvoiceRefByHashAndAddr(payHash lntypes.Hash,
payAddr [32]byte) InvoiceRef {
return InvoiceRef{
payHash: payHash,
payAddr: &payAddr,
}
}
// PayHash returns the target invoice's payment hash. // PayHash returns the target invoice's payment hash.
func (r InvoiceRef) PayHash() lntypes.Hash { func (r InvoiceRef) PayHash() lntypes.Hash {
return r.payHash return r.payHash
} }
// PayAddr returns the optional payment address of the target invoice.
//
// NOTE: This value may be nil.
func (r InvoiceRef) PayAddr() *[32]byte {
if r.payAddr != nil {
addr := *r.payAddr
return &addr
}
return nil
}
// String returns a human-readable representation of an InvoiceRef. // String returns a human-readable representation of an InvoiceRef.
func (r InvoiceRef) String() string { func (r InvoiceRef) String() string {
if r.payAddr != nil {
return fmt.Sprintf("(pay_hash=%v, pay_addr=%x)", r.payHash, *r.payAddr)
}
return fmt.Sprintf("(pay_hash=%v)", r.payHash) return fmt.Sprintf("(pay_hash=%v)", r.payHash)
} }
@ -458,6 +505,11 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
return ErrDuplicateInvoice return ErrDuplicateInvoice
} }
payAddrIndex := tx.ReadWriteBucket(payAddrIndexBucket)
if payAddrIndex.Get(newInvoice.Terms.PaymentAddr[:]) != nil {
return ErrDuplicatePayAddr
}
// If the current running payment ID counter hasn't yet been // If the current running payment ID counter hasn't yet been
// created, then create it now. // created, then create it now.
var invoiceNum uint32 var invoiceNum uint32
@ -474,8 +526,8 @@ func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
} }
newIndex, err := putInvoice( newIndex, err := putInvoice(
invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, invoices, invoiceIndex, payAddrIndex, addIndex,
paymentHash, newInvoice, invoiceNum, paymentHash,
) )
if err != nil { if err != nil {
return err return err
@ -575,11 +627,12 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
if invoiceIndex == nil { if invoiceIndex == nil {
return ErrNoInvoicesCreated return ErrNoInvoicesCreated
} }
payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
// Retrieve the invoice number for this invoice using the // Retrieve the invoice number for this invoice using the
// provided invoice reference. // provided invoice reference.
invoiceNum, err := fetchInvoiceNumByRef( invoiceNum, err := fetchInvoiceNumByRef(
invoiceIndex, ref, invoiceIndex, payAddrIndex, ref,
) )
if err != nil { if err != nil {
return err return err
@ -603,18 +656,44 @@ func (d *DB) LookupInvoice(ref InvoiceRef) (Invoice, error) {
} }
// fetchInvoiceNumByRef retrieve the invoice number for the provided invoice // fetchInvoiceNumByRef retrieve the invoice number for the provided invoice
// reference. // reference. The payment address will be treated as the primary key, falling
func fetchInvoiceNumByRef(invoiceIndex kvdb.ReadBucket, // back to the payment hash if nothing is found for the payment address. An
// error is returned if the invoice is not found.
func fetchInvoiceNumByRef(invoiceIndex, payAddrIndex kvdb.ReadBucket,
ref InvoiceRef) ([]byte, error) { ref InvoiceRef) ([]byte, error) {
payHash := ref.PayHash() payHash := ref.PayHash()
payAddr := ref.PayAddr()
invoiceNum := invoiceIndex.Get(payHash[:]) var (
if invoiceNum == nil { invoiceNumByHash = invoiceIndex.Get(payHash[:])
return nil, ErrInvoiceNotFound invoiceNumByAddr []byte
)
if payAddr != nil {
invoiceNumByAddr = payAddrIndex.Get(payAddr[:])
} }
return invoiceNum, nil switch {
// If payment address and payment hash both reference an existing
// invoice, ensure they reference the _same_ invoice.
case invoiceNumByAddr != nil && invoiceNumByHash != nil:
if !bytes.Equal(invoiceNumByAddr, invoiceNumByHash) {
return nil, ErrInvRefEquivocation
}
return invoiceNumByAddr, nil
// If we were only able to reference the invoice by hash, return the
// corresponding invoice number. This can happen when no payment address
// was provided, or if it didn't match anything in our records.
case invoiceNumByHash != nil:
return invoiceNumByHash, nil
// Otherwise we don't know of the target invoice.
default:
return nil, ErrInvoiceNotFound
}
} }
// InvoiceWithPaymentHash is used to store an invoice and its corresponding // InvoiceWithPaymentHash is used to store an invoice and its corresponding
@ -888,11 +967,12 @@ func (d *DB) UpdateInvoice(ref InvoiceRef,
if err != nil { if err != nil {
return err return err
} }
payAddrIndex := tx.ReadBucket(payAddrIndexBucket)
// Retrieve the invoice number for this invoice using the // Retrieve the invoice number for this invoice using the
// provided invoice reference. // provided invoice reference.
invoiceNum, err := fetchInvoiceNumByRef( invoiceNum, err := fetchInvoiceNumByRef(
invoiceIndex, ref, invoiceIndex, payAddrIndex, ref,
) )
if err != nil { if err != nil {
return err return err
@ -971,7 +1051,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
return settledInvoices, nil return settledInvoices, nil
} }
func putInvoice(invoices, invoiceIndex, addIndex kvdb.RwBucket, func putInvoice(invoices, invoiceIndex, payAddrIndex, addIndex kvdb.RwBucket,
i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
uint64, error) { uint64, error) {
@ -996,6 +1076,10 @@ func putInvoice(invoices, invoiceIndex, addIndex kvdb.RwBucket,
if err != nil { if err != nil {
return 0, err return 0, err
} }
err = payAddrIndex.Put(i.Terms.PaymentAddr[:], invoiceKey[:])
if err != nil {
return 0, err
}
// Next, we'll obtain the next add invoice index (sequence // Next, we'll obtain the next add invoice index (sequence
// number), so we can properly place this invoice within this // number), so we can properly place this invoice within this

@ -2,7 +2,7 @@ package htlcswitch
import ( import (
"bytes" "bytes"
"crypto/rand" crand "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
@ -137,7 +137,7 @@ func generateRandomBytes(n int) ([]byte, error) {
// TODO(roasbeef): should use counter in tests (atomic) rather than // TODO(roasbeef): should use counter in tests (atomic) rather than
// this // this
_, err := rand.Read(b[:]) _, err := crand.Read(b)
// Note that Err == nil only if we read len(b) bytes. // Note that Err == nil only if we read len(b) bytes.
if err != nil { if err != nil {
return nil, err return nil, err
@ -547,7 +547,7 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) {
// invoice which should be added by destination peer. // invoice which should be added by destination peer.
func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32, blob [lnwire.OnionPacketSize]byte, timelock uint32, blob [lnwire.OnionPacketSize]byte,
preimage, rhash [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, preimage, rhash, payAddr [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC,
uint64, error) { uint64, error) {
// Create the db invoice. Normally the payment requests needs to be set, // Create the db invoice. Normally the payment requests needs to be set,
@ -562,6 +562,7 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
FinalCltvDelta: testInvoiceCltvExpiry, FinalCltvDelta: testInvoiceCltvExpiry,
Value: invoiceAmt, Value: invoiceAmt,
PaymentPreimage: preimage, PaymentPreimage: preimage,
PaymentAddr: payAddr,
Features: lnwire.NewFeatureVector( Features: lnwire.NewFeatureVector(
nil, lnwire.Features, nil, lnwire.Features,
), ),
@ -598,8 +599,16 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32,
copy(preimage[:], r) copy(preimage[:], r)
rhash := sha256.Sum256(preimage[:]) rhash := sha256.Sum256(preimage[:])
var payAddr [sha256.Size]byte
r, err = generateRandomBytes(sha256.Size)
if err != nil {
return nil, nil, 0, err
}
copy(payAddr[:], r)
return generatePaymentWithPreimage( return generatePaymentWithPreimage(
invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, payAddr,
) )
} }
@ -1328,10 +1337,15 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
rhash := preimage.Hash() rhash := preimage.Hash()
var payAddr [32]byte
if _, err := crand.Read(payAddr[:]); err != nil {
panic(err)
}
// Generate payment: invoice and htlc. // Generate payment: invoice and htlc.
invoice, htlc, pid, err := generatePaymentWithPreimage( invoice, htlc, pid, err := generatePaymentWithPreimage(
invoiceAmt, htlcAmt, timelock, blob, invoiceAmt, htlcAmt, timelock, blob,
channeldb.UnknownPreimage, rhash, channeldb.UnknownPreimage, rhash, payAddr,
) )
if err != nil { if err != nil {
paymentErr <- err paymentErr <- err

@ -1,6 +1,7 @@
package invoices package invoices
import ( import (
"crypto/rand"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@ -198,14 +199,20 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
expiry = time.Hour expiry = time.Hour
} }
var payAddr [32]byte
if _, err := rand.Read(payAddr[:]); err != nil {
t.Fatalf("unable to generate payment addr: %v", err)
}
rawInvoice, err := zpay32.NewInvoice( rawInvoice, err := zpay32.NewInvoice(
testNetParams, testNetParams,
preimage.Hash(), preimage.Hash(),
timestamp, timestamp,
zpay32.Amount(testInvoiceAmount), zpay32.Amount(testInvoiceAmount),
zpay32.Description(testInvoiceDescription), zpay32.Description(testInvoiceDescription),
zpay32.Expiry(expiry)) zpay32.Expiry(expiry),
zpay32.PaymentAddr(payAddr),
)
if err != nil { if err != nil {
t.Fatalf("Error while creating new invoice: %v", err) t.Fatalf("Error while creating new invoice: %v", err)
} }
@ -219,6 +226,7 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
return &channeldb.Invoice{ return &channeldb.Invoice{
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
PaymentPreimage: preimage, PaymentPreimage: preimage,
PaymentAddr: payAddr,
Value: testInvoiceAmount, Value: testInvoiceAmount,
Expiry: expiry, Expiry: expiry,
Features: testFeatures, Features: testFeatures,

@ -25,6 +25,10 @@ type invoiceUpdateCtx struct {
// invoiceRef returns an identifier that can be used to lookup or update the // invoiceRef returns an identifier that can be used to lookup or update the
// invoice this HTLC is targeting. // invoice this HTLC is targeting.
func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef { func (i *invoiceUpdateCtx) invoiceRef() channeldb.InvoiceRef {
if i.mpp != nil {
payAddr := i.mpp.PaymentAddr()
return channeldb.InvoiceRefByHashAndAddr(i.hash, payAddr)
}
return channeldb.InvoiceRefByHash(i.hash) return channeldb.InvoiceRefByHash(i.hash)
} }