Merge pull request #4334 from joostjager/hold-keysend-part1

invoices: add explicit hodl invoice flag
This commit is contained in:
Joost Jager 2020-06-02 10:59:47 +02:00 committed by GitHub
commit 9f32942a90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 153 additions and 93 deletions

@ -20,7 +20,10 @@ var (
) )
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
var pre, payAddr [32]byte var (
pre lntypes.Preimage
payAddr [32]byte
)
if _, err := rand.Read(pre[:]); err != nil { if _, err := rand.Read(pre[:]); err != nil {
return nil, err return nil, err
} }
@ -32,7 +35,7 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
CreationDate: testNow, CreationDate: testNow,
Terms: ContractTerm{ Terms: ContractTerm{
Expiry: 4000, Expiry: 4000,
PaymentPreimage: pre, PaymentPreimage: &pre,
PaymentAddr: payAddr, PaymentAddr: payAddr,
Value: value, Value: value,
Features: emptyFeatures, Features: emptyFeatures,
@ -360,13 +363,18 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) {
t.Fatalf("unable to make test db: %v", err) t.Fatalf("unable to make test db: %v", err)
} }
preimage := lntypes.Preimage{1}
paymentHash := preimage.Hash()
testInvoice := &Invoice{ testInvoice := &Invoice{
Htlcs: map[CircuitKey]*InvoiceHTLC{}, Htlcs: map[CircuitKey]*InvoiceHTLC{},
Terms: ContractTerm{
Value: lnwire.NewMSatFromSatoshis(10000),
Features: emptyFeatures,
PaymentPreimage: &preimage,
},
} }
testInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
testInvoice.Terms.Features = emptyFeatures
var paymentHash lntypes.Hash
if _, err := db.AddInvoice(testInvoice, paymentHash); err != nil { if _, err := db.AddInvoice(testInvoice, paymentHash); err != nil {
t.Fatalf("unable to find invoice: %v", err) t.Fatalf("unable to find invoice: %v", err)
} }
@ -1059,15 +1067,20 @@ func TestCustomRecords(t *testing.T) {
t.Fatalf("unable to make test db: %v", err) t.Fatalf("unable to make test db: %v", err)
} }
preimage := lntypes.Preimage{1}
paymentHash := preimage.Hash()
testInvoice := &Invoice{ testInvoice := &Invoice{
Htlcs: map[CircuitKey]*InvoiceHTLC{}, Htlcs: map[CircuitKey]*InvoiceHTLC{},
Terms: ContractTerm{
Value: lnwire.NewMSatFromSatoshis(10000),
Features: emptyFeatures,
PaymentPreimage: &preimage,
},
} }
testInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
testInvoice.Terms.Features = emptyFeatures
var paymentHash lntypes.Hash
if _, err := db.AddInvoice(testInvoice, paymentHash); err != nil { if _, err := db.AddInvoice(testInvoice, paymentHash); err != nil {
t.Fatalf("unable to find invoice: %v", err) t.Fatalf("unable to add invoice: %v", err)
} }
// Accept an htlc with custom records on this invoice. // Accept an htlc with custom records on this invoice.

@ -17,9 +17,9 @@ import (
) )
var ( var (
// UnknownPreimage is an all-zeroes preimage that indicates that the // unknownPreimage is an all-zeroes preimage that indicates that the
// preimage for this invoice is not yet known. // preimage for this invoice is not yet known.
UnknownPreimage lntypes.Preimage unknownPreimage lntypes.Preimage
// invoiceBucket is the name of the bucket within the database that // invoiceBucket is the name of the bucket within the database that
// stores all data related to invoices no matter their final state. // stores all data related to invoices no matter their final state.
@ -150,6 +150,7 @@ const (
featuresType tlv.Type = 11 featuresType tlv.Type = 11
invStateType tlv.Type = 12 invStateType tlv.Type = 12
amtPaidType tlv.Type = 13 amtPaidType tlv.Type = 13
hodlInvoiceType tlv.Type = 14
) )
// InvoiceRef is a composite identifier for invoices. Invoices can be referenced // InvoiceRef is a composite identifier for invoices. Invoices can be referenced
@ -261,8 +262,8 @@ type ContractTerm struct {
// PaymentPreimage is the preimage which is to be revealed in the // PaymentPreimage is the preimage which is to be revealed in the
// occasion that an HTLC paying to the hash of this preimage is // occasion that an HTLC paying to the hash of this preimage is
// extended. // extended. Set to nil if the preimage isn't known yet.
PaymentPreimage lntypes.Preimage PaymentPreimage *lntypes.Preimage
// Value is the expected amount of milli-satoshis to be paid to an HTLC // Value is the expected amount of milli-satoshis to be paid to an HTLC
// which can be satisfied by the above preimage. // which can be satisfied by the above preimage.
@ -346,6 +347,10 @@ type Invoice struct {
// Htlcs records all htlcs that paid to this invoice. Some of these // Htlcs records all htlcs that paid to this invoice. Some of these
// htlcs may have been marked as canceled. // htlcs may have been marked as canceled.
Htlcs map[CircuitKey]*InvoiceHTLC Htlcs map[CircuitKey]*InvoiceHTLC
// HodlInvoice indicates whether the invoice should be held in the
// Accepted state or be settled right away.
HodlInvoice bool
} }
// HtlcState defines the states an htlc paying to an invoice can be in. // HtlcState defines the states an htlc paying to an invoice can be in.
@ -439,14 +444,19 @@ type InvoiceStateUpdateDesc struct {
NewState ContractState NewState ContractState
// Preimage must be set to the preimage when NewState is settled. // Preimage must be set to the preimage when NewState is settled.
Preimage lntypes.Preimage Preimage *lntypes.Preimage
} }
// InvoiceUpdateCallback is a callback used in the db transaction to update the // InvoiceUpdateCallback is a callback used in the db transaction to update the
// invoice. // invoice.
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
func validateInvoice(i *Invoice) error { func validateInvoice(i *Invoice, paymentHash lntypes.Hash) error {
// Avoid conflicts with all-zeroes magic value in the database.
if paymentHash == unknownPreimage.Hash() {
return fmt.Errorf("cannot use hash of all-zeroes preimage")
}
if len(i.Memo) > MaxMemoSize { if len(i.Memo) > MaxMemoSize {
return fmt.Errorf("max length a memo is %v, and invoice "+ return fmt.Errorf("max length a memo is %v, and invoice "+
"of length %v was provided", MaxMemoSize, len(i.Memo)) "of length %v was provided", MaxMemoSize, len(i.Memo))
@ -459,6 +469,10 @@ func validateInvoice(i *Invoice) error {
if i.Terms.Features == nil { if i.Terms.Features == nil {
return errors.New("invoice must have a feature vector") return errors.New("invoice must have a feature vector")
} }
if i.Terms.PaymentPreimage == nil && !i.HodlInvoice {
return errors.New("non-hodl invoices must have a preimage")
}
return nil return nil
} }
@ -475,7 +489,7 @@ func (i *Invoice) IsPending() bool {
func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
uint64, error) { uint64, error) {
if err := validateInvoice(newInvoice); err != nil { if err := validateInvoice(newInvoice, paymentHash); err != nil {
return 0, err return 0, err
} }
@ -1131,7 +1145,13 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
} }
featureBytes := fb.Bytes() featureBytes := fb.Bytes()
preimage := [32]byte(i.Terms.PaymentPreimage) preimage := [32]byte(unknownPreimage)
if i.Terms.PaymentPreimage != nil {
preimage = *i.Terms.PaymentPreimage
if preimage == unknownPreimage {
return errors.New("cannot use all-zeroes preimage")
}
}
value := uint64(i.Terms.Value) value := uint64(i.Terms.Value)
cltvDelta := uint32(i.Terms.FinalCltvDelta) cltvDelta := uint32(i.Terms.FinalCltvDelta)
expiry := uint64(i.Terms.Expiry) expiry := uint64(i.Terms.Expiry)
@ -1139,6 +1159,11 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
amtPaid := uint64(i.AmtPaid) amtPaid := uint64(i.AmtPaid)
state := uint8(i.State) state := uint8(i.State)
var hodlInvoice uint8
if i.HodlInvoice {
hodlInvoice = 1
}
tlvStream, err := tlv.NewStream( tlvStream, err := tlv.NewStream(
// Memo and payreq. // Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo), tlv.MakePrimitiveRecord(memoType, &i.Memo),
@ -1161,6 +1186,8 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
// Invoice state. // Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state), tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
) )
if err != nil { if err != nil {
return err return err
@ -1256,12 +1283,13 @@ func fetchInvoice(invoiceNum []byte, invoices kvdb.RBucket) (Invoice, error) {
func deserializeInvoice(r io.Reader) (Invoice, error) { func deserializeInvoice(r io.Reader) (Invoice, error) {
var ( var (
preimage [32]byte preimageBytes [32]byte
value uint64 value uint64
cltvDelta uint32 cltvDelta uint32
expiry uint64 expiry uint64
amtPaid uint64 amtPaid uint64
state uint8 state uint8
hodlInvoice uint8
creationDateBytes []byte creationDateBytes []byte
settleDateBytes []byte settleDateBytes []byte
@ -1281,7 +1309,7 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex), tlv.MakePrimitiveRecord(settleIndexType, &i.SettleIndex),
// Terms. // Terms.
tlv.MakePrimitiveRecord(preimageType, &preimage), tlv.MakePrimitiveRecord(preimageType, &preimageBytes),
tlv.MakePrimitiveRecord(valueType, &value), tlv.MakePrimitiveRecord(valueType, &value),
tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta), tlv.MakePrimitiveRecord(cltvDeltaType, &cltvDelta),
tlv.MakePrimitiveRecord(expiryType, &expiry), tlv.MakePrimitiveRecord(expiryType, &expiry),
@ -1291,6 +1319,8 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
// Invoice state. // Invoice state.
tlv.MakePrimitiveRecord(invStateType, &state), tlv.MakePrimitiveRecord(invStateType, &state),
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
) )
if err != nil { if err != nil {
return i, err return i, err
@ -1307,13 +1337,21 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
return i, err return i, err
} }
i.Terms.PaymentPreimage = lntypes.Preimage(preimage) preimage := lntypes.Preimage(preimageBytes)
if preimage != unknownPreimage {
i.Terms.PaymentPreimage = &preimage
}
i.Terms.Value = lnwire.MilliSatoshi(value) i.Terms.Value = lnwire.MilliSatoshi(value)
i.Terms.FinalCltvDelta = int32(cltvDelta) i.Terms.FinalCltvDelta = int32(cltvDelta)
i.Terms.Expiry = time.Duration(expiry) i.Terms.Expiry = time.Duration(expiry)
i.AmtPaid = lnwire.MilliSatoshi(amtPaid) i.AmtPaid = lnwire.MilliSatoshi(amtPaid)
i.State = ContractState(state) i.State = ContractState(state)
if hodlInvoice != 0 {
i.HodlInvoice = true
}
err = i.CreationDate.UnmarshalBinary(creationDateBytes) err = i.CreationDate.UnmarshalBinary(creationDateBytes)
if err != nil { if err != nil {
return i, err return i, err
@ -1443,10 +1481,16 @@ func copyInvoice(src *Invoice) *Invoice {
Htlcs: make( Htlcs: make(
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
), ),
HodlInvoice: src.HodlInvoice,
} }
dest.Terms.Features = src.Terms.Features.Clone() dest.Terms.Features = src.Terms.Features.Clone()
if src.Terms.PaymentPreimage != nil {
preimage := *src.Terms.PaymentPreimage
dest.Terms.PaymentPreimage = &preimage
}
for k, v := range src.Htlcs { for k, v := range src.Htlcs {
dest.Htlcs[k] = copyInvoiceHTLC(v) dest.Htlcs[k] = copyInvoiceHTLC(v)
} }
@ -1619,10 +1663,16 @@ func updateInvoiceState(invoice *Invoice, hash lntypes.Hash,
case ContractOpen: case ContractOpen:
if update.NewState == ContractSettled { if update.NewState == ContractSettled {
// Validate preimage. // Validate preimage.
switch {
case update.Preimage != nil:
if update.Preimage.Hash() != hash { if update.Preimage.Hash() != hash {
return ErrInvoicePreimageMismatch return ErrInvoicePreimageMismatch
} }
invoice.Terms.PaymentPreimage = update.Preimage invoice.Terms.PaymentPreimage = update.Preimage
case invoice.Terms.PaymentPreimage == nil:
return errors.New("unknown preimage")
}
} }
// Once settled, we are in a terminal state. // Once settled, we are in a terminal state.

@ -1457,8 +1457,7 @@ func (c *ChannelArbitrator) isPreimageAvailable(hash lntypes.Hash) (bool,
return false, err return false, err
} }
preimageAvailable = invoice.Terms.PaymentPreimage != preimageAvailable = invoice.Terms.PaymentPreimage != nil
channeldb.UnknownPreimage
return preimageAvailable, nil return preimageAvailable, nil
} }

@ -437,7 +437,9 @@ func TestChannelLinkCancelFullCommitment(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
preimages[i] = lntypes.Preimage{byte(i >> 8), byte(i)} // Deterministically generate preimages. Avoid the all-zeroes
// preimage because that will be rejected by the database.
preimages[i] = lntypes.Preimage{byte(i >> 8), byte(i), 1}
wg.Add(1) wg.Add(1)
go func(i int) { go func(i int) {
@ -2015,13 +2017,13 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
// If we now send in a valid HTLC settle for the prior HTLC we added, // If we now send in a valid HTLC settle for the prior HTLC we added,
// then the bandwidth should remain unchanged as the remote party will // then the bandwidth should remain unchanged as the remote party will
// gain additional channel balance. // gain additional channel balance.
err = bobChannel.SettleHTLC(invoice.Terms.PaymentPreimage, bobIndex, nil, nil, nil) err = bobChannel.SettleHTLC(*invoice.Terms.PaymentPreimage, bobIndex, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("unable to settle htlc: %v", err) t.Fatalf("unable to settle htlc: %v", err)
} }
htlcSettle := &lnwire.UpdateFulfillHTLC{ htlcSettle := &lnwire.UpdateFulfillHTLC{
ID: 0, ID: 0,
PaymentPreimage: invoice.Terms.PaymentPreimage, PaymentPreimage: *invoice.Terms.PaymentPreimage,
} }
aliceLink.HandleChannelUpdate(htlcSettle) aliceLink.HandleChannelUpdate(htlcSettle)
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
@ -2193,7 +2195,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
outgoingHTLCID: addPkt.outgoingHTLCID, outgoingHTLCID: addPkt.outgoingHTLCID,
htlc: &lnwire.UpdateFulfillHTLC{ htlc: &lnwire.UpdateFulfillHTLC{
ID: 0, ID: 0,
PaymentPreimage: invoice.Terms.PaymentPreimage, PaymentPreimage: *invoice.Terms.PaymentPreimage,
}, },
obfuscator: NewMockObfuscator(), obfuscator: NewMockObfuscator(),
} }
@ -3153,13 +3155,13 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) {
// If we now send in a valid HTLC settle for the prior HTLC we added, // If we now send in a valid HTLC settle for the prior HTLC we added,
// then the bandwidth should remain unchanged as the remote party will // then the bandwidth should remain unchanged as the remote party will
// gain additional channel balance. // gain additional channel balance.
err = bobChannel.SettleHTLC(invoice.Terms.PaymentPreimage, bobIndex, nil, nil, nil) err = bobChannel.SettleHTLC(*invoice.Terms.PaymentPreimage, bobIndex, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("unable to settle htlc: %v", err) t.Fatalf("unable to settle htlc: %v", err)
} }
htlcSettle := &lnwire.UpdateFulfillHTLC{ htlcSettle := &lnwire.UpdateFulfillHTLC{
ID: bobIndex, ID: bobIndex,
PaymentPreimage: invoice.Terms.PaymentPreimage, PaymentPreimage: *invoice.Terms.PaymentPreimage,
} }
aliceLink.HandleChannelUpdate(htlcSettle) aliceLink.HandleChannelUpdate(htlcSettle)
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
@ -4730,7 +4732,7 @@ func testChannelLinkBatchPreimageWrite(t *testing.T, disconnect bool) {
for i, invoice := range invoices { for i, invoice := range invoices {
ctx.sendSettleBobToAlice( ctx.sendSettleBobToAlice(
uint64(i), uint64(i),
invoice.Terms.PaymentPreimage, *invoice.Terms.PaymentPreimage,
) )
} }
@ -5772,7 +5774,8 @@ func TestChannelLinkHoldInvoiceRestart(t *testing.T) {
// Convert into a hodl invoice and save the preimage for later. // Convert into a hodl invoice and save the preimage for later.
preimage := invoice.Terms.PaymentPreimage preimage := invoice.Terms.PaymentPreimage
invoice.Terms.PaymentPreimage = channeldb.UnknownPreimage invoice.Terms.PaymentPreimage = nil
invoice.HodlInvoice = true
// We must add the invoice to the registry, such that Alice // We must add the invoice to the registry, such that Alice
// expects this payment. // expects this payment.
@ -5814,7 +5817,10 @@ func TestChannelLinkHoldInvoiceRestart(t *testing.T) {
<-registry.settleChan <-registry.settleChan
// Settle the invoice with the preimage. // Settle the invoice with the preimage.
registry.SettleHodlInvoice(preimage) err = registry.SettleHodlInvoice(*preimage)
if err != nil {
t.Fatalf("settle hodl invoice: %v", err)
}
// Expect alice to send a settle and commitsig message to bob. // Expect alice to send a settle and commitsig message to bob.
ctx.receiveSettleAliceToBob() ctx.receiveSettleAliceToBob()
@ -5957,10 +5963,12 @@ func TestChannelLinkRevocationWindowHodl(t *testing.T) {
// Convert into hodl invoices and save the preimages for later. // Convert into hodl invoices and save the preimages for later.
preimage1 := invoice1.Terms.PaymentPreimage preimage1 := invoice1.Terms.PaymentPreimage
invoice1.Terms.PaymentPreimage = channeldb.UnknownPreimage invoice1.Terms.PaymentPreimage = nil
invoice1.HodlInvoice = true
preimage2 := invoice2.Terms.PaymentPreimage preimage2 := invoice2.Terms.PaymentPreimage
invoice2.Terms.PaymentPreimage = channeldb.UnknownPreimage invoice2.Terms.PaymentPreimage = nil
invoice2.HodlInvoice = true
// We must add the invoices to the registry, such that Alice // We must add the invoices to the registry, such that Alice
// expects the payments. // expects the payments.
@ -6009,7 +6017,10 @@ func TestChannelLinkRevocationWindowHodl(t *testing.T) {
} }
// Settle invoice 1 with the preimage. // Settle invoice 1 with the preimage.
registry.SettleHodlInvoice(preimage1) err = registry.SettleHodlInvoice(*preimage1)
if err != nil {
t.Fatalf("settle hodl invoice: %v", err)
}
// Expect alice to send a settle and commitsig message to bob. Bob does // Expect alice to send a settle and commitsig message to bob. Bob does
// not yet send the revocation. // not yet send the revocation.
@ -6017,7 +6028,10 @@ func TestChannelLinkRevocationWindowHodl(t *testing.T) {
ctx.receiveCommitSigAliceToBob(1) ctx.receiveCommitSigAliceToBob(1)
// Settle invoice 2 with the preimage. // Settle invoice 2 with the preimage.
registry.SettleHodlInvoice(preimage2) err = registry.SettleHodlInvoice(*preimage2)
if err != nil {
t.Fatalf("settle hodl invoice: %v", err)
}
// Expect alice to send a settle for htlc 2. // Expect alice to send a settle for htlc 2.
ctx.receiveSettleAliceToBob() ctx.receiveSettleAliceToBob()

@ -547,8 +547,8 @@ func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) {
// invoice which should be added by destination peer. // invoice which should be added by destination peer.
func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi, func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32, blob [lnwire.OnionPacketSize]byte, timelock uint32, blob [lnwire.OnionPacketSize]byte,
preimage, rhash, payAddr [32]byte) (*channeldb.Invoice, *lnwire.UpdateAddHTLC, preimage *lntypes.Preimage, rhash, payAddr [32]byte) (
uint64, error) { *channeldb.Invoice, *lnwire.UpdateAddHTLC, uint64, error) {
// Create the db invoice. Normally the payment requests needs to be set, // Create the db invoice. Normally the payment requests needs to be set,
// because it is decoded in InvoiceRegistry to obtain the cltv expiry. // because it is decoded in InvoiceRegistry to obtain the cltv expiry.
@ -556,6 +556,7 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
// step and always returning the value of testInvoiceCltvExpiry, we // step and always returning the value of testInvoiceCltvExpiry, we
// don't need to bother here with creating and signing a payment // don't need to bother here with creating and signing a payment
// request. // request.
invoice := &channeldb.Invoice{ invoice := &channeldb.Invoice{
CreationDate: time.Now(), CreationDate: time.Now(),
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
@ -567,6 +568,7 @@ func generatePaymentWithPreimage(invoiceAmt, htlcAmt lnwire.MilliSatoshi,
nil, lnwire.Features, nil, lnwire.Features,
), ),
}, },
HodlInvoice: preimage == nil,
} }
htlc := &lnwire.UpdateAddHTLC{ htlc := &lnwire.UpdateAddHTLC{
@ -591,7 +593,7 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32,
blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice, blob [lnwire.OnionPacketSize]byte) (*channeldb.Invoice,
*lnwire.UpdateAddHTLC, uint64, error) { *lnwire.UpdateAddHTLC, uint64, error) {
var preimage [sha256.Size]byte var preimage lntypes.Preimage
r, err := generateRandomBytes(sha256.Size) r, err := generateRandomBytes(sha256.Size)
if err != nil { if err != nil {
return nil, nil, 0, err return nil, nil, 0, err
@ -608,7 +610,7 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32,
copy(payAddr[:], r) copy(payAddr[:], r)
return generatePaymentWithPreimage( return generatePaymentWithPreimage(
invoiceAmt, htlcAmt, timelock, blob, preimage, rhash, payAddr, invoiceAmt, htlcAmt, timelock, blob, &preimage, rhash, payAddr,
) )
} }
@ -1345,7 +1347,7 @@ func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
// Generate payment: invoice and htlc. // Generate payment: invoice and htlc.
invoice, htlc, pid, err := generatePaymentWithPreimage( invoice, htlc, pid, err := generatePaymentWithPreimage(
invoiceAmt, htlcAmt, timelock, blob, invoiceAmt, htlcAmt, timelock, blob,
channeldb.UnknownPreimage, rhash, payAddr, nil, rhash, payAddr,
) )
if err != nil { if err != nil {
paymentErr <- err paymentErr <- err

@ -652,12 +652,6 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error {
return errors.New("invalid keysend preimage") return errors.New("invalid keysend preimage")
} }
// Don't accept zero preimages as those have a special meaning in our
// database for hodl invoices.
if preimage == channeldb.UnknownPreimage {
return errors.New("invalid keysend preimage")
}
// Only allow keysend for non-mpp payments. // Only allow keysend for non-mpp payments.
if ctx.mpp != nil { if ctx.mpp != nil {
return errors.New("no mpp keysend supported") return errors.New("no mpp keysend supported")
@ -688,7 +682,7 @@ func (i *InvoiceRegistry) processKeySend(ctx invoiceUpdateCtx) error {
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
FinalCltvDelta: finalCltvDelta, FinalCltvDelta: finalCltvDelta,
Value: amt, Value: amt,
PaymentPreimage: preimage, PaymentPreimage: &preimage,
Features: features, Features: features,
}, },
} }
@ -948,7 +942,7 @@ func (i *InvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error {
return &channeldb.InvoiceUpdateDesc{ return &channeldb.InvoiceUpdateDesc{
State: &channeldb.InvoiceStateUpdateDesc{ State: &channeldb.InvoiceStateUpdateDesc{
NewState: channeldb.ContractSettled, NewState: channeldb.ContractSettled,
Preimage: preimage, Preimage: &preimage,
}, },
}, nil }, nil
} }

@ -92,7 +92,7 @@ var (
testInvoiceAmt = lnwire.MilliSatoshi(100000) testInvoiceAmt = lnwire.MilliSatoshi(100000)
testInvoice = &channeldb.Invoice{ testInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
PaymentPreimage: testInvoicePreimage, PaymentPreimage: &testInvoicePreimage,
Value: testInvoiceAmt, Value: testInvoiceAmt,
Expiry: time.Hour, Expiry: time.Hour,
Features: testFeatures, Features: testFeatures,
@ -102,12 +102,12 @@ var (
testHodlInvoice = &channeldb.Invoice{ testHodlInvoice = &channeldb.Invoice{
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
PaymentPreimage: channeldb.UnknownPreimage,
Value: testInvoiceAmt, Value: testInvoiceAmt,
Expiry: time.Hour, Expiry: time.Hour,
Features: testFeatures, Features: testFeatures,
}, },
CreationDate: testInvoiceCreationDate, CreationDate: testInvoiceCreationDate,
HodlInvoice: true,
} }
) )
@ -225,7 +225,7 @@ func newTestInvoice(t *testing.T, preimage lntypes.Preimage,
return &channeldb.Invoice{ return &channeldb.Invoice{
Terms: channeldb.ContractTerm{ Terms: channeldb.ContractTerm{
PaymentPreimage: preimage, PaymentPreimage: &preimage,
PaymentAddr: payAddr, PaymentAddr: payAddr,
Value: testInvoiceAmount, Value: testInvoiceAmount,
Expiry: expiry, Expiry: expiry,

@ -81,7 +81,7 @@ func updateInvoice(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (
case channeldb.HtlcStateSettled: case channeldb.HtlcStateSettled:
return nil, ctx.settleRes( return nil, ctx.settleRes(
inv.Terms.PaymentPreimage, *inv.Terms.PaymentPreimage,
ResultReplayToSettled, ResultReplayToSettled,
), nil ), nil
@ -187,8 +187,7 @@ func updateMpp(ctx *invoiceUpdateCtx,
// Check to see if we can settle or this is an hold invoice and // Check to see if we can settle or this is an hold invoice and
// we need to wait for the preimage. // we need to wait for the preimage.
holdInvoice := inv.Terms.PaymentPreimage == channeldb.UnknownPreimage if inv.HodlInvoice {
if holdInvoice {
update.State = &channeldb.InvoiceStateUpdateDesc{ update.State = &channeldb.InvoiceStateUpdateDesc{
NewState: channeldb.ContractAccepted, NewState: channeldb.ContractAccepted,
} }
@ -201,7 +200,7 @@ func updateMpp(ctx *invoiceUpdateCtx,
} }
return &update, ctx.settleRes( return &update, ctx.settleRes(
inv.Terms.PaymentPreimage, ResultSettled, *inv.Terms.PaymentPreimage, ResultSettled,
), nil ), nil
} }
@ -269,14 +268,13 @@ func updateLegacy(ctx *invoiceUpdateCtx,
case channeldb.ContractSettled: case channeldb.ContractSettled:
return &update, ctx.settleRes( return &update, ctx.settleRes(
inv.Terms.PaymentPreimage, ResultDuplicateToSettled, *inv.Terms.PaymentPreimage, ResultDuplicateToSettled,
), nil ), nil
} }
// Check to see if we can settle or this is an hold invoice and we need // Check to see if we can settle or this is an hold invoice and we need
// to wait for the preimage. // to wait for the preimage.
holdInvoice := inv.Terms.PaymentPreimage == channeldb.UnknownPreimage if inv.HodlInvoice {
if holdInvoice {
update.State = &channeldb.InvoiceStateUpdateDesc{ update.State = &channeldb.InvoiceStateUpdateDesc{
NewState: channeldb.ContractAccepted, NewState: channeldb.ContractAccepted,
} }
@ -290,6 +288,6 @@ func updateLegacy(ctx *invoiceUpdateCtx,
} }
return &update, ctx.settleRes( return &update, ctx.settleRes(
inv.Terms.PaymentPreimage, ResultSettled, *inv.Terms.PaymentPreimage, ResultSettled,
), nil ), nil
} }

@ -88,6 +88,10 @@ type AddInvoiceData struct {
// Whether this invoice should include routing hints for private // Whether this invoice should include routing hints for private
// channels. // channels.
Private bool Private bool
// HodlInvoice signals that this invoice shouldn't be settled
// immediately upon receiving the payment.
HodlInvoice bool
} }
// AddInvoice attempts to add a new invoice to the invoice database. Any // AddInvoice attempts to add a new invoice to the invoice database. Any
@ -97,7 +101,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
invoice *AddInvoiceData) (*lntypes.Hash, *channeldb.Invoice, error) { invoice *AddInvoiceData) (*lntypes.Hash, *channeldb.Invoice, error) {
var ( var (
paymentPreimage lntypes.Preimage paymentPreimage *lntypes.Preimage
paymentHash lntypes.Hash paymentHash lntypes.Hash
) )
@ -108,26 +112,9 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
return nil, nil, return nil, nil,
errors.New("preimage and hash both set") errors.New("preimage and hash both set")
// Prevent the unknown preimage magic value from being used for a
// regular invoice. This would cause the invoice the be handled as if it
// was a hold invoice.
case invoice.Preimage != nil &&
*invoice.Preimage == channeldb.UnknownPreimage:
return nil, nil,
fmt.Errorf("cannot use all zeroes as a preimage")
// Prevent the hash of the unknown preimage magic value to be used for a
// hold invoice. This would make it impossible to settle the invoice,
// because it would still be interpreted as not having a preimage.
case invoice.Hash != nil &&
*invoice.Hash == channeldb.UnknownPreimage.Hash():
return nil, nil,
fmt.Errorf("cannot use hash of all zeroes preimage")
// If no hash or preimage is given, generate a random preimage. // If no hash or preimage is given, generate a random preimage.
case invoice.Preimage == nil && invoice.Hash == nil: case invoice.Preimage == nil && invoice.Hash == nil:
paymentPreimage = &lntypes.Preimage{}
if _, err := rand.Read(paymentPreimage[:]); err != nil { if _, err := rand.Read(paymentPreimage[:]); err != nil {
return nil, nil, err return nil, nil, err
} }
@ -136,12 +123,12 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
// If just a hash is given, we create a hold invoice by setting the // If just a hash is given, we create a hold invoice by setting the
// preimage to unknown. // preimage to unknown.
case invoice.Preimage == nil && invoice.Hash != nil: case invoice.Preimage == nil && invoice.Hash != nil:
paymentPreimage = channeldb.UnknownPreimage
paymentHash = *invoice.Hash paymentHash = *invoice.Hash
// A specific preimage was supplied. Use that for the invoice. // A specific preimage was supplied. Use that for the invoice.
case invoice.Preimage != nil && invoice.Hash == nil: case invoice.Preimage != nil && invoice.Hash == nil:
paymentPreimage = *invoice.Preimage preimage := *invoice.Preimage
paymentPreimage = &preimage
paymentHash = invoice.Preimage.Hash() paymentHash = invoice.Preimage.Hash()
} }
@ -410,6 +397,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
PaymentAddr: paymentAddr, PaymentAddr: paymentAddr,
Features: invoiceFeatures, Features: invoiceFeatures,
}, },
HodlInvoice: invoice.HodlInvoice,
} }
log.Tracef("[addinvoice] adding new invoice %v", log.Tracef("[addinvoice] adding new invoice %v",

@ -274,6 +274,8 @@ func (s *Server) AddHoldInvoice(ctx context.Context,
FallbackAddr: invoice.FallbackAddr, FallbackAddr: invoice.FallbackAddr,
CltvExpiry: invoice.CltvExpiry, CltvExpiry: invoice.CltvExpiry,
Private: invoice.Private, Private: invoice.Private,
HodlInvoice: true,
Preimage: nil,
} }
_, dbInvoice, err := AddInvoice(ctx, addInvoiceCfg, addInvoiceData) _, dbInvoice, err := AddInvoice(ctx, addInvoiceCfg, addInvoiceData)

@ -22,7 +22,7 @@ func decodePayReq(invoice *channeldb.Invoice,
paymentRequest := string(invoice.PaymentRequest) paymentRequest := string(invoice.PaymentRequest)
if paymentRequest == "" { if paymentRequest == "" {
preimage := invoice.Terms.PaymentPreimage preimage := invoice.Terms.PaymentPreimage
if preimage == channeldb.UnknownPreimage { if preimage == nil {
return nil, errors.New("cannot reconstruct pay req") return nil, errors.New("cannot reconstruct pay req")
} }
hash := [32]byte(preimage.Hash()) hash := [32]byte(preimage.Hash())
@ -149,7 +149,7 @@ func CreateRPCInvoice(invoice *channeldb.Invoice,
IsKeysend: len(invoice.PaymentRequest) == 0, IsKeysend: len(invoice.PaymentRequest) == 0,
} }
if preimage != channeldb.UnknownPreimage { if preimage != nil {
rpcInvoice.RPreimage = preimage[:] rpcInvoice.RPreimage = preimage[:]
} }