channeldb+invoices: add invoice htlcs
This commit adds a set of htlcs to the Invoice struct and serializes/deserializes this set to/from disk. It is a preparation for accurate invoice accounting across restarts of lnd. A migration is added for the invoice htlcs. In addition to these changes, separate final cltv delta and expiry invoice fields are created and populated. Previously it was required to decode this from the stored payment request. The reason to create a combined commit is to prevent multiple migrations.
This commit is contained in:
parent
061b34b924
commit
4105142c96
@ -110,6 +110,11 @@ var (
|
||||
number: 10,
|
||||
migration: migrateRouteSerialization,
|
||||
},
|
||||
{
|
||||
// Add invoice htlc and cltv delta fields.
|
||||
number: 11,
|
||||
migration: migrateInvoices,
|
||||
},
|
||||
}
|
||||
|
||||
// Big endian is the preferred byte order, due to cursor scans over
|
||||
|
@ -10,6 +10,26 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
testCircuitKey = CircuitKey{
|
||||
ChanID: lnwire.ShortChannelID{
|
||||
BlockHeight: 1, TxIndex: 2, TxPosition: 3,
|
||||
},
|
||||
HtlcID: 4,
|
||||
}
|
||||
|
||||
testHtlcs = map[CircuitKey]*InvoiceHTLC{
|
||||
testCircuitKey: {
|
||||
State: HtlcStateCancelled,
|
||||
AcceptTime: time.Unix(1, 0),
|
||||
AcceptHeight: 100,
|
||||
ResolveTime: time.Unix(2, 0),
|
||||
Amt: 5200,
|
||||
Expiry: 150,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
|
||||
var pre [32]byte
|
||||
if _, err := rand.Read(pre[:]); err != nil {
|
||||
@ -24,6 +44,9 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
|
||||
PaymentPreimage: pre,
|
||||
Value: value,
|
||||
},
|
||||
Htlcs: testHtlcs,
|
||||
FinalCltvDelta: 50,
|
||||
Expiry: 4000,
|
||||
}
|
||||
i.Memo = []byte("memo")
|
||||
i.Receipt = []byte("receipt")
|
||||
@ -59,6 +82,7 @@ func TestInvoiceWorkflow(t *testing.T) {
|
||||
// Use single second precision to avoid false positive test
|
||||
// failures due to the monotonic time component.
|
||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
||||
Htlcs: testHtlcs,
|
||||
}
|
||||
fakeInvoice.Memo = []byte("memo")
|
||||
fakeInvoice.Receipt = []byte("receipt")
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -92,6 +93,17 @@ const (
|
||||
// TODO(halseth): determine the max length payment request when field
|
||||
// lengths are final.
|
||||
MaxPaymentRequestSize = 4096
|
||||
|
||||
// A set of tlv type definitions used to serialize invoice htlcs to the
|
||||
// database.
|
||||
chanIDType tlv.Type = 1
|
||||
htlcIDType tlv.Type = 3
|
||||
amtType tlv.Type = 5
|
||||
acceptHeightType tlv.Type = 7
|
||||
acceptTimeType tlv.Type = 9
|
||||
resolveTimeType tlv.Type = 11
|
||||
expiryHeightType tlv.Type = 13
|
||||
stateType tlv.Type = 15
|
||||
)
|
||||
|
||||
// ContractState describes the state the invoice is in.
|
||||
@ -172,6 +184,13 @@ type Invoice struct {
|
||||
// for this invoice can be stored.
|
||||
PaymentRequest []byte
|
||||
|
||||
// FinalCltvDelta is the minimum required number of blocks before htlc
|
||||
// expiry when the invoice is accepted.
|
||||
FinalCltvDelta int32
|
||||
|
||||
// Expiry defines how long after creation this invoice should expire.
|
||||
Expiry time.Duration
|
||||
|
||||
// CreationDate is the exact time the invoice was created.
|
||||
CreationDate time.Time
|
||||
|
||||
@ -209,6 +228,52 @@ type Invoice struct {
|
||||
// that the invoice originally didn't specify an amount, or the sender
|
||||
// overpaid.
|
||||
AmtPaid lnwire.MilliSatoshi
|
||||
|
||||
// Htlcs records all htlcs that paid to this invoice. Some of these
|
||||
// htlcs may have been marked as cancelled.
|
||||
Htlcs map[CircuitKey]*InvoiceHTLC
|
||||
}
|
||||
|
||||
// HtlcState defines the states an htlc paying to an invoice can be in.
|
||||
type HtlcState uint8
|
||||
|
||||
const (
|
||||
// HtlcStateAccepted indicates the htlc is locked-in, but not resolved.
|
||||
HtlcStateAccepted HtlcState = iota
|
||||
|
||||
// HtlcStateCancelled indicates the htlc is cancelled back to the
|
||||
// sender.
|
||||
HtlcStateCancelled
|
||||
|
||||
// HtlcStateSettled indicates the htlc is settled.
|
||||
HtlcStateSettled
|
||||
)
|
||||
|
||||
// InvoiceHTLC contains details about an htlc paying to this invoice.
|
||||
type InvoiceHTLC struct {
|
||||
// Amt is the amount that is carried by this htlc.
|
||||
Amt lnwire.MilliSatoshi
|
||||
|
||||
// AcceptHeight is the block height at which the invoice registry
|
||||
// decided to accept this htlc as a payment to the invoice. At this
|
||||
// height, the invoice cltv delay must have been met.
|
||||
AcceptHeight uint32
|
||||
|
||||
// AcceptTime is the wall clock time at which the invoice registry
|
||||
// decided to accept the htlc.
|
||||
AcceptTime time.Time
|
||||
|
||||
// ResolveTime is the wall clock time at which the invoice registry
|
||||
// decided to settle the htlc.
|
||||
ResolveTime time.Time
|
||||
|
||||
// Expiry is the expiry height of this htlc.
|
||||
Expiry uint32
|
||||
|
||||
// State indicates the state the invoice htlc is currently in. A
|
||||
// cancelled htlc isn't just removed from the invoice htlcs map, because
|
||||
// we need AcceptedHeight to properly cancel the htlc back.
|
||||
State HtlcState
|
||||
}
|
||||
|
||||
func validateInvoice(i *Invoice) error {
|
||||
@ -865,6 +930,11 @@ func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket,
|
||||
return nextAddSeqNo, nil
|
||||
}
|
||||
|
||||
// serializeInvoice serializes an invoice to a writer.
|
||||
//
|
||||
// Note: this function is in use for a migration. Before making changes that
|
||||
// would modify the on disk format, make a copy of the original code and store
|
||||
// it with the migration.
|
||||
func serializeInvoice(w io.Writer, i *Invoice) error {
|
||||
if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil {
|
||||
return err
|
||||
@ -876,6 +946,14 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
birthBytes, err := i.CreationDate.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -918,6 +996,57 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := serializeHtlcs(w, i.Htlcs); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to
|
||||
// a writer.
|
||||
func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
|
||||
for key, htlc := range htlcs {
|
||||
// Encode the htlc in a tlv stream.
|
||||
chanID := key.ChanID.ToUint64()
|
||||
amt := uint64(htlc.Amt)
|
||||
acceptTime := uint64(htlc.AcceptTime.UnixNano())
|
||||
resolveTime := uint64(htlc.ResolveTime.UnixNano())
|
||||
state := uint8(htlc.State)
|
||||
|
||||
tlvStream, err := tlv.NewStream(
|
||||
tlv.MakePrimitiveRecord(chanIDType, &chanID),
|
||||
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
|
||||
tlv.MakePrimitiveRecord(amtType, &amt),
|
||||
tlv.MakePrimitiveRecord(
|
||||
acceptHeightType, &htlc.AcceptHeight,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
|
||||
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
|
||||
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
|
||||
tlv.MakePrimitiveRecord(stateType, &state),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tlvStream.Encode(&b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the length of the tlv stream followed by the stream
|
||||
// bytes.
|
||||
err = binary.Write(w, byteOrder, uint64(b.Len()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.Write(b.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -951,6 +1080,16 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
var expiry int64
|
||||
if err := binary.Read(r, byteOrder, &expiry); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
invoice.Expiry = time.Duration(expiry)
|
||||
|
||||
birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth")
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
@ -990,9 +1129,77 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
invoice.Htlcs, err = deserializeHtlcs(r)
|
||||
if err != nil {
|
||||
return Invoice{}, err
|
||||
}
|
||||
|
||||
return invoice, nil
|
||||
}
|
||||
|
||||
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
|
||||
// as a map.
|
||||
func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
|
||||
htlcs := make(map[CircuitKey]*InvoiceHTLC, 0)
|
||||
|
||||
for {
|
||||
// Read the length of the tlv stream for this htlc.
|
||||
var streamLen uint64
|
||||
if err := binary.Read(r, byteOrder, &streamLen); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
streamBytes := make([]byte, streamLen)
|
||||
if _, err := r.Read(streamBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
streamReader := bytes.NewReader(streamBytes)
|
||||
|
||||
// Decode the contents into the htlc fields.
|
||||
var (
|
||||
htlc InvoiceHTLC
|
||||
key CircuitKey
|
||||
chanID uint64
|
||||
state uint8
|
||||
acceptTime, resolveTime uint64
|
||||
amt uint64
|
||||
)
|
||||
tlvStream, err := tlv.NewStream(
|
||||
tlv.MakePrimitiveRecord(chanIDType, &chanID),
|
||||
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
|
||||
tlv.MakePrimitiveRecord(amtType, &amt),
|
||||
tlv.MakePrimitiveRecord(
|
||||
acceptHeightType, &htlc.AcceptHeight,
|
||||
),
|
||||
tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime),
|
||||
tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime),
|
||||
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
|
||||
tlv.MakePrimitiveRecord(stateType, &state),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tlvStream.Decode(streamReader); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key.ChanID = lnwire.NewShortChanIDFromInt(chanID)
|
||||
htlc.AcceptTime = time.Unix(0, int64(acceptTime))
|
||||
htlc.ResolveTime = time.Unix(0, int64(resolveTime))
|
||||
htlc.State = HtlcState(state)
|
||||
htlc.Amt = lnwire.MilliSatoshi(amt)
|
||||
|
||||
htlcs[key] = &htlc
|
||||
}
|
||||
|
||||
return htlcs, nil
|
||||
}
|
||||
|
||||
func acceptOrSettleInvoice(invoices, settleIndex *bbolt.Bucket,
|
||||
invoiceNum []byte, amtPaid lnwire.MilliSatoshi,
|
||||
checkHtlcParameters func(invoice *Invoice) error) (
|
||||
|
@ -177,7 +177,7 @@ func fetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, er
|
||||
func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error {
|
||||
var scratch [8]byte
|
||||
|
||||
if err := serializeInvoice(w, &p.Invoice); err != nil {
|
||||
if err := serializeInvoiceLegacy(w, &p.Invoice); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -218,7 +218,7 @@ func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) {
|
||||
|
||||
p := &outgoingPayment{}
|
||||
|
||||
inv, err := deserializeInvoice(r)
|
||||
inv, err := deserializeInvoiceLegacy(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
225
channeldb/migration_11_invoices.go
Normal file
225
channeldb/migration_11_invoices.go
Normal file
@ -0,0 +1,225 @@
|
||||
package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
bitcoinCfg "github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
litecoinCfg "github.com/ltcsuite/ltcd/chaincfg"
|
||||
)
|
||||
|
||||
// migrateInvoices adds invoice htlcs and a separate cltv delta field to the
|
||||
// invoices.
|
||||
func migrateInvoices(tx *bbolt.Tx) error {
|
||||
log.Infof("Migrating invoices to new invoice format")
|
||||
|
||||
invoiceB := tx.Bucket(invoiceBucket)
|
||||
if invoiceB == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Iterate through the entire key space of the top-level invoice bucket.
|
||||
// If key with a non-nil value stores the next invoice ID which maps to
|
||||
// the corresponding invoice. Store those keys first, because it isn't
|
||||
// safe to modify the bucket inside a ForEach loop.
|
||||
var invoiceKeys [][]byte
|
||||
err := invoiceB.ForEach(func(k, v []byte) error {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
invoiceKeys = append(invoiceKeys, k)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nets := []*bitcoinCfg.Params{
|
||||
&bitcoinCfg.MainNetParams, &bitcoinCfg.SimNetParams,
|
||||
&bitcoinCfg.RegressionNetParams, &bitcoinCfg.TestNet3Params,
|
||||
}
|
||||
|
||||
ltcNets := []*litecoinCfg.Params{
|
||||
&litecoinCfg.MainNetParams, &litecoinCfg.SimNetParams,
|
||||
&litecoinCfg.RegressionNetParams, &litecoinCfg.TestNet4Params,
|
||||
}
|
||||
for _, net := range ltcNets {
|
||||
var convertedNet bitcoinCfg.Params
|
||||
convertedNet.Bech32HRPSegwit = net.Bech32HRPSegwit
|
||||
nets = append(nets, &convertedNet)
|
||||
}
|
||||
|
||||
// Iterate over all stored keys and migrate the invoices.
|
||||
for _, k := range invoiceKeys {
|
||||
v := invoiceB.Get(k)
|
||||
|
||||
// Deserialize the invoice with the deserializing function that
|
||||
// was in use for this version of the database.
|
||||
invoiceReader := bytes.NewReader(v)
|
||||
invoice, err := deserializeInvoiceLegacy(invoiceReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to decode the payment request for every possible net to
|
||||
// avoid passing a the active network to channeldb. This would
|
||||
// be a layering violation, while this migration is only running
|
||||
// once and will likely be removed in the future.
|
||||
var payReq *zpay32.Invoice
|
||||
for _, net := range nets {
|
||||
payReq, err = zpay32.Decode(
|
||||
string(invoice.PaymentRequest), net,
|
||||
)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if payReq == nil {
|
||||
return fmt.Errorf("cannot decode payreq")
|
||||
}
|
||||
invoice.FinalCltvDelta = int32(payReq.MinFinalCLTVExpiry())
|
||||
invoice.Expiry = payReq.Expiry()
|
||||
|
||||
// Serialize the invoice in the new format and use it to replace
|
||||
// the old invoice in the database.
|
||||
var buf bytes.Buffer
|
||||
if err := serializeInvoice(&buf, &invoice); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = invoiceB.Put(k, buf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Migration of invoices completed!")
|
||||
return nil
|
||||
}
|
||||
|
||||
func deserializeInvoiceLegacy(r io.Reader) (Invoice, error) {
|
||||
var err error
|
||||
invoice := Invoice{}
|
||||
|
||||
// TODO(roasbeef): use read full everywhere
|
||||
invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "")
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "")
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "")
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth")
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled")
|
||||
if err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
var scratch [8]byte
|
||||
if _, err := io.ReadFull(r, scratch[:]); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:]))
|
||||
|
||||
if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil {
|
||||
return invoice, err
|
||||
}
|
||||
|
||||
return invoice, nil
|
||||
}
|
||||
|
||||
// serializeInvoiceLegacy serializes an invoice in the format of the previous db
|
||||
// version.
|
||||
func serializeInvoiceLegacy(w io.Writer, i *Invoice) error {
|
||||
if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
birthBytes, err := i.CreationDate.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settleBytes, err := i.SettleDate.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var scratch [8]byte
|
||||
byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value))
|
||||
if _, err := w.Write(scratch[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, byteOrder, i.Terms.State); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, byteOrder, i.AddIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
166
channeldb/migration_11_invoices_test.go
Normal file
166
channeldb/migration_11_invoices_test.go
Normal file
@ -0,0 +1,166 @@
|
||||
package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
bitcoinCfg "github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/zpay32"
|
||||
litecoinCfg "github.com/ltcsuite/ltcd/chaincfg"
|
||||
)
|
||||
|
||||
var (
|
||||
testPrivKeyBytes = []byte{
|
||||
0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf,
|
||||
0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9,
|
||||
0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f,
|
||||
0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90,
|
||||
}
|
||||
|
||||
testCltvDelta = int32(50)
|
||||
)
|
||||
|
||||
// TestMigrateInvoices checks that invoices are migrated correctly.
|
||||
func TestMigrateInvoices(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var ltcNetParams bitcoinCfg.Params
|
||||
ltcNetParams.Bech32HRPSegwit = litecoinCfg.MainNetParams.Bech32HRPSegwit
|
||||
payReqLtc, err := getPayReq(<cNetParams)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
invoices := []Invoice{
|
||||
{
|
||||
PaymentRequest: []byte(payReqBtc),
|
||||
},
|
||||
{
|
||||
PaymentRequest: []byte(payReqLtc),
|
||||
},
|
||||
}
|
||||
|
||||
beforeMigrationFunc := func(d *DB) {
|
||||
err := d.Update(func(tx *bbolt.Tx) error {
|
||||
invoicesBucket, err := tx.CreateBucketIfNotExists(
|
||||
invoiceBucket,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
invoiceNum := uint32(1)
|
||||
for _, invoice := range invoices {
|
||||
var invoiceKey [4]byte
|
||||
byteOrder.PutUint32(invoiceKey[:], invoiceNum)
|
||||
invoiceNum++
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := serializeInvoiceLegacy(&buf, &invoice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = invoicesBucket.Put(
|
||||
invoiceKey[:], buf.Bytes(),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that all invoices were migrated.
|
||||
afterMigrationFunc := func(d *DB) {
|
||||
meta, err := d.FetchMeta(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if meta.DbVersionNumber != 1 {
|
||||
t.Fatal("migration 'invoices' wasn't applied")
|
||||
}
|
||||
|
||||
dbInvoices, err := d.FetchAllInvoices(false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch invoices: %v", err)
|
||||
}
|
||||
|
||||
if len(invoices) != len(dbInvoices) {
|
||||
t.Fatalf("expected %d invoices, got %d", len(invoices),
|
||||
len(dbInvoices))
|
||||
}
|
||||
|
||||
for _, dbInvoice := range dbInvoices {
|
||||
if dbInvoice.FinalCltvDelta != testCltvDelta {
|
||||
t.Fatal("incorrect final cltv delta")
|
||||
}
|
||||
if dbInvoice.Expiry != 3600*time.Second {
|
||||
t.Fatal("incorrect expiry")
|
||||
}
|
||||
if len(dbInvoice.Htlcs) != 0 {
|
||||
t.Fatal("expected no htlcs after migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applyMigration(t,
|
||||
beforeMigrationFunc,
|
||||
afterMigrationFunc,
|
||||
migrateInvoices,
|
||||
false)
|
||||
}
|
||||
|
||||
// signDigestCompact generates a test signature to be used in the generation of
|
||||
// test payment requests.
|
||||
func signDigestCompact(hash []byte) ([]byte, error) {
|
||||
// Should the signature reference a compressed public key or not.
|
||||
isCompressedKey := true
|
||||
|
||||
privKey, _ := btcec.PrivKeyFromBytes(btcec.S256(), testPrivKeyBytes)
|
||||
|
||||
// btcec.SignCompact returns a pubkey-recoverable signature
|
||||
sig, err := btcec.SignCompact(
|
||||
btcec.S256(), privKey, hash, isCompressedKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't sign the hash: %v", err)
|
||||
}
|
||||
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
// getPayReq creates a payment request for the given net.
|
||||
func getPayReq(net *bitcoinCfg.Params) (string, error) {
|
||||
options := []func(*zpay32.Invoice){
|
||||
zpay32.CLTVExpiry(uint64(testCltvDelta)),
|
||||
zpay32.Description("test"),
|
||||
}
|
||||
|
||||
payReq, err := zpay32.NewInvoice(
|
||||
net, [32]byte{}, time.Unix(1, 0), options...,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return payReq.Encode(
|
||||
zpay32.MessageSigner{
|
||||
SignCompact: signDigestCompact,
|
||||
},
|
||||
)
|
||||
}
|
@ -168,7 +168,7 @@ func migrateInvoiceTimeSeries(tx *bbolt.Tx) error {
|
||||
invoiceBytesCopy = append(invoiceBytesCopy, padding...)
|
||||
|
||||
invoiceReader := bytes.NewReader(invoiceBytesCopy)
|
||||
invoice, err := deserializeInvoice(invoiceReader)
|
||||
invoice, err := deserializeInvoiceLegacy(invoiceReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to decode invoice: %v", err)
|
||||
}
|
||||
@ -227,7 +227,7 @@ func migrateInvoiceTimeSeries(tx *bbolt.Tx) error {
|
||||
// We've fully migrated an invoice, so we'll now update the
|
||||
// invoice in-place.
|
||||
var b bytes.Buffer
|
||||
if err := serializeInvoice(&b, &invoice); err != nil {
|
||||
if err := serializeInvoiceLegacy(&b, &invoice); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -394,6 +394,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,
|
||||
Memo: []byte(invoice.Memo),
|
||||
Receipt: invoice.Receipt,
|
||||
PaymentRequest: []byte(payReqString),
|
||||
FinalCltvDelta: int32(payReq.MinFinalCLTVExpiry()),
|
||||
Expiry: payReq.Expiry(),
|
||||
Terms: channeldb.ContractTerm{
|
||||
Value: amtMSat,
|
||||
PaymentPreimage: paymentPreimage,
|
||||
|
Loading…
Reference in New Issue
Block a user