diff --git a/channeldb/db.go b/channeldb/db.go index 4c96123f..c53b7a23 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -110,6 +110,11 @@ var ( number: 10, migration: migrateRouteSerialization, }, + { + // Add invoice htlc and cltv delta fields. + number: 11, + migration: migrateInvoices, + }, } // Big endian is the preferred byte order, due to cursor scans over diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 288edfef..580fe3d3 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -10,6 +10,26 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +var ( + testCircuitKey = CircuitKey{ + ChanID: lnwire.ShortChannelID{ + BlockHeight: 1, TxIndex: 2, TxPosition: 3, + }, + HtlcID: 4, + } + + testHtlcs = map[CircuitKey]*InvoiceHTLC{ + testCircuitKey: { + State: HtlcStateCancelled, + AcceptTime: time.Unix(1, 0), + AcceptHeight: 100, + ResolveTime: time.Unix(2, 0), + Amt: 5200, + Expiry: 150, + }, + } +) + func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { var pre [32]byte if _, err := rand.Read(pre[:]); err != nil { @@ -24,6 +44,9 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { PaymentPreimage: pre, Value: value, }, + Htlcs: testHtlcs, + FinalCltvDelta: 50, + Expiry: 4000, } i.Memo = []byte("memo") i.Receipt = []byte("receipt") @@ -59,6 +82,7 @@ func TestInvoiceWorkflow(t *testing.T) { // Use single second precision to avoid false positive test // failures due to the monotonic time component. CreationDate: time.Unix(time.Now().Unix(), 0), + Htlcs: testHtlcs, } fakeInvoice.Memo = []byte("memo") fakeInvoice.Receipt = []byte("receipt") diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 339761df..18dbf75f 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -12,6 +12,7 @@ import ( "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -92,6 +93,17 @@ const ( // TODO(halseth): determine the max length payment request when field // lengths are final. MaxPaymentRequestSize = 4096 + + // A set of tlv type definitions used to serialize invoice htlcs to the + // database. + chanIDType tlv.Type = 1 + htlcIDType tlv.Type = 3 + amtType tlv.Type = 5 + acceptHeightType tlv.Type = 7 + acceptTimeType tlv.Type = 9 + resolveTimeType tlv.Type = 11 + expiryHeightType tlv.Type = 13 + stateType tlv.Type = 15 ) // ContractState describes the state the invoice is in. @@ -172,6 +184,13 @@ type Invoice struct { // for this invoice can be stored. PaymentRequest []byte + // FinalCltvDelta is the minimum required number of blocks before htlc + // expiry when the invoice is accepted. + FinalCltvDelta int32 + + // Expiry defines how long after creation this invoice should expire. + Expiry time.Duration + // CreationDate is the exact time the invoice was created. CreationDate time.Time @@ -209,6 +228,52 @@ type Invoice struct { // that the invoice originally didn't specify an amount, or the sender // overpaid. AmtPaid lnwire.MilliSatoshi + + // Htlcs records all htlcs that paid to this invoice. Some of these + // htlcs may have been marked as cancelled. + Htlcs map[CircuitKey]*InvoiceHTLC +} + +// HtlcState defines the states an htlc paying to an invoice can be in. +type HtlcState uint8 + +const ( + // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. + HtlcStateAccepted HtlcState = iota + + // HtlcStateCancelled indicates the htlc is cancelled back to the + // sender. + HtlcStateCancelled + + // HtlcStateSettled indicates the htlc is settled. + HtlcStateSettled +) + +// InvoiceHTLC contains details about an htlc paying to this invoice. +type InvoiceHTLC struct { + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // AcceptHeight is the block height at which the invoice registry + // decided to accept this htlc as a payment to the invoice. At this + // height, the invoice cltv delay must have been met. + AcceptHeight uint32 + + // AcceptTime is the wall clock time at which the invoice registry + // decided to accept the htlc. + AcceptTime time.Time + + // ResolveTime is the wall clock time at which the invoice registry + // decided to settle the htlc. + ResolveTime time.Time + + // Expiry is the expiry height of this htlc. + Expiry uint32 + + // State indicates the state the invoice htlc is currently in. A + // cancelled htlc isn't just removed from the invoice htlcs map, because + // we need AcceptedHeight to properly cancel the htlc back. + State HtlcState } func validateInvoice(i *Invoice) error { @@ -865,6 +930,11 @@ func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket, return nextAddSeqNo, nil } +// serializeInvoice serializes an invoice to a writer. +// +// Note: this function is in use for a migration. Before making changes that +// would modify the on disk format, make a copy of the original code and store +// it with the migration. func serializeInvoice(w io.Writer, i *Invoice) error { if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { return err @@ -876,6 +946,14 @@ func serializeInvoice(w io.Writer, i *Invoice) error { return err } + if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil { + return err + } + birthBytes, err := i.CreationDate.MarshalBinary() if err != nil { return err @@ -918,6 +996,57 @@ func serializeInvoice(w io.Writer, i *Invoice) error { return err } + if err := serializeHtlcs(w, i.Htlcs); err != nil { + return err + } + + return nil +} + +// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to +// a writer. +func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error { + for key, htlc := range htlcs { + // Encode the htlc in a tlv stream. + chanID := key.ChanID.ToUint64() + amt := uint64(htlc.Amt) + acceptTime := uint64(htlc.AcceptTime.UnixNano()) + resolveTime := uint64(htlc.ResolveTime.UnixNano()) + state := uint8(htlc.State) + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(chanIDType, &chanID), + tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), + tlv.MakePrimitiveRecord(amtType, &amt), + tlv.MakePrimitiveRecord( + acceptHeightType, &htlc.AcceptHeight, + ), + tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), + tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), + tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), + tlv.MakePrimitiveRecord(stateType, &state), + ) + if err != nil { + return err + } + + var b bytes.Buffer + if err := tlvStream.Encode(&b); err != nil { + return err + } + + // Write the length of the tlv stream followed by the stream + // bytes. + err = binary.Write(w, byteOrder, uint64(b.Len())) + if err != nil { + return err + } + + if _, err := w.Write(b.Bytes()); err != nil { + return err + } + } + return nil } @@ -951,6 +1080,16 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { return invoice, err } + if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil { + return invoice, err + } + + var expiry int64 + if err := binary.Read(r, byteOrder, &expiry); err != nil { + return invoice, err + } + invoice.Expiry = time.Duration(expiry) + birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") if err != nil { return invoice, err @@ -990,9 +1129,77 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { return invoice, err } + invoice.Htlcs, err = deserializeHtlcs(r) + if err != nil { + return Invoice{}, err + } + return invoice, nil } +// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it +// as a map. +func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { + htlcs := make(map[CircuitKey]*InvoiceHTLC, 0) + + for { + // Read the length of the tlv stream for this htlc. + var streamLen uint64 + if err := binary.Read(r, byteOrder, &streamLen); err != nil { + if err == io.EOF { + break + } + + return nil, err + } + + streamBytes := make([]byte, streamLen) + if _, err := r.Read(streamBytes); err != nil { + return nil, err + } + streamReader := bytes.NewReader(streamBytes) + + // Decode the contents into the htlc fields. + var ( + htlc InvoiceHTLC + key CircuitKey + chanID uint64 + state uint8 + acceptTime, resolveTime uint64 + amt uint64 + ) + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(chanIDType, &chanID), + tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), + tlv.MakePrimitiveRecord(amtType, &amt), + tlv.MakePrimitiveRecord( + acceptHeightType, &htlc.AcceptHeight, + ), + tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), + tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), + tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), + tlv.MakePrimitiveRecord(stateType, &state), + ) + if err != nil { + return nil, err + } + + if err := tlvStream.Decode(streamReader); err != nil { + return nil, err + } + + key.ChanID = lnwire.NewShortChanIDFromInt(chanID) + htlc.AcceptTime = time.Unix(0, int64(acceptTime)) + htlc.ResolveTime = time.Unix(0, int64(resolveTime)) + htlc.State = HtlcState(state) + htlc.Amt = lnwire.MilliSatoshi(amt) + + htlcs[key] = &htlc + } + + return htlcs, nil +} + func acceptOrSettleInvoice(invoices, settleIndex *bbolt.Bucket, invoiceNum []byte, amtPaid lnwire.MilliSatoshi, checkHtlcParameters func(invoice *Invoice) error) ( diff --git a/channeldb/migration_09_legacy_serialization.go b/channeldb/migration_09_legacy_serialization.go index 52e765ed..56e36ab1 100644 --- a/channeldb/migration_09_legacy_serialization.go +++ b/channeldb/migration_09_legacy_serialization.go @@ -177,7 +177,7 @@ func fetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, er func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error { var scratch [8]byte - if err := serializeInvoice(w, &p.Invoice); err != nil { + if err := serializeInvoiceLegacy(w, &p.Invoice); err != nil { return err } @@ -218,7 +218,7 @@ func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) { p := &outgoingPayment{} - inv, err := deserializeInvoice(r) + inv, err := deserializeInvoiceLegacy(r) if err != nil { return nil, err } diff --git a/channeldb/migration_11_invoices.go b/channeldb/migration_11_invoices.go new file mode 100644 index 00000000..e242309b --- /dev/null +++ b/channeldb/migration_11_invoices.go @@ -0,0 +1,225 @@ +package channeldb + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + bitcoinCfg "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/zpay32" + litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" +) + +// migrateInvoices adds invoice htlcs and a separate cltv delta field to the +// invoices. +func migrateInvoices(tx *bbolt.Tx) error { + log.Infof("Migrating invoices to new invoice format") + + invoiceB := tx.Bucket(invoiceBucket) + if invoiceB == nil { + return nil + } + + // Iterate through the entire key space of the top-level invoice bucket. + // If key with a non-nil value stores the next invoice ID which maps to + // the corresponding invoice. Store those keys first, because it isn't + // safe to modify the bucket inside a ForEach loop. + var invoiceKeys [][]byte + err := invoiceB.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + invoiceKeys = append(invoiceKeys, k) + + return nil + }) + if err != nil { + return err + } + + nets := []*bitcoinCfg.Params{ + &bitcoinCfg.MainNetParams, &bitcoinCfg.SimNetParams, + &bitcoinCfg.RegressionNetParams, &bitcoinCfg.TestNet3Params, + } + + ltcNets := []*litecoinCfg.Params{ + &litecoinCfg.MainNetParams, &litecoinCfg.SimNetParams, + &litecoinCfg.RegressionNetParams, &litecoinCfg.TestNet4Params, + } + for _, net := range ltcNets { + var convertedNet bitcoinCfg.Params + convertedNet.Bech32HRPSegwit = net.Bech32HRPSegwit + nets = append(nets, &convertedNet) + } + + // Iterate over all stored keys and migrate the invoices. + for _, k := range invoiceKeys { + v := invoiceB.Get(k) + + // Deserialize the invoice with the deserializing function that + // was in use for this version of the database. + invoiceReader := bytes.NewReader(v) + invoice, err := deserializeInvoiceLegacy(invoiceReader) + if err != nil { + return err + } + + // Try to decode the payment request for every possible net to + // avoid passing a the active network to channeldb. This would + // be a layering violation, while this migration is only running + // once and will likely be removed in the future. + var payReq *zpay32.Invoice + for _, net := range nets { + payReq, err = zpay32.Decode( + string(invoice.PaymentRequest), net, + ) + if err == nil { + break + } + } + if payReq == nil { + return fmt.Errorf("cannot decode payreq") + } + invoice.FinalCltvDelta = int32(payReq.MinFinalCLTVExpiry()) + invoice.Expiry = payReq.Expiry() + + // Serialize the invoice in the new format and use it to replace + // the old invoice in the database. + var buf bytes.Buffer + if err := serializeInvoice(&buf, &invoice); err != nil { + return err + } + + err = invoiceB.Put(k, buf.Bytes()) + if err != nil { + return err + } + } + + log.Infof("Migration of invoices completed!") + return nil +} + +func deserializeInvoiceLegacy(r io.Reader) (Invoice, error) { + var err error + invoice := Invoice{} + + // TODO(roasbeef): use read full everywhere + invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") + if err != nil { + return invoice, err + } + invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") + if err != nil { + return invoice, err + } + + invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") + if err != nil { + return invoice, err + } + + birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") + if err != nil { + return invoice, err + } + if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { + return invoice, err + } + + settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") + if err != nil { + return invoice, err + } + if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { + return invoice, err + } + + if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { + return invoice, err + } + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return invoice, err + } + invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { + return invoice, err + } + + return invoice, nil +} + +// serializeInvoiceLegacy serializes an invoice in the format of the previous db +// version. +func serializeInvoiceLegacy(w io.Writer, i *Invoice) error { + if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { + return err + } + + birthBytes, err := i.CreationDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { + return err + } + + settleBytes, err := i.SettleDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { + return err + } + + if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { + return err + } + + var scratch [8]byte + byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { + return err + } + + return nil +} diff --git a/channeldb/migration_11_invoices_test.go b/channeldb/migration_11_invoices_test.go new file mode 100644 index 00000000..9739af80 --- /dev/null +++ b/channeldb/migration_11_invoices_test.go @@ -0,0 +1,166 @@ +package channeldb + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + bitcoinCfg "github.com/btcsuite/btcd/chaincfg" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/zpay32" + litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" +) + +var ( + testPrivKeyBytes = []byte{ + 0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf, + 0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9, + 0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f, + 0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90, + } + + testCltvDelta = int32(50) +) + +// TestMigrateInvoices checks that invoices are migrated correctly. +func TestMigrateInvoices(t *testing.T) { + t.Parallel() + + payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + var ltcNetParams bitcoinCfg.Params + ltcNetParams.Bech32HRPSegwit = litecoinCfg.MainNetParams.Bech32HRPSegwit + payReqLtc, err := getPayReq(<cNetParams) + if err != nil { + t.Fatal(err) + } + + invoices := []Invoice{ + { + PaymentRequest: []byte(payReqBtc), + }, + { + PaymentRequest: []byte(payReqLtc), + }, + } + + beforeMigrationFunc := func(d *DB) { + err := d.Update(func(tx *bbolt.Tx) error { + invoicesBucket, err := tx.CreateBucketIfNotExists( + invoiceBucket, + ) + if err != nil { + return err + } + + invoiceNum := uint32(1) + for _, invoice := range invoices { + var invoiceKey [4]byte + byteOrder.PutUint32(invoiceKey[:], invoiceNum) + invoiceNum++ + + var buf bytes.Buffer + err := serializeInvoiceLegacy(&buf, &invoice) + if err != nil { + return err + } + + err = invoicesBucket.Put( + invoiceKey[:], buf.Bytes(), + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + // Verify that all invoices were migrated. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'invoices' wasn't applied") + } + + dbInvoices, err := d.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch invoices: %v", err) + } + + if len(invoices) != len(dbInvoices) { + t.Fatalf("expected %d invoices, got %d", len(invoices), + len(dbInvoices)) + } + + for _, dbInvoice := range dbInvoices { + if dbInvoice.FinalCltvDelta != testCltvDelta { + t.Fatal("incorrect final cltv delta") + } + if dbInvoice.Expiry != 3600*time.Second { + t.Fatal("incorrect expiry") + } + if len(dbInvoice.Htlcs) != 0 { + t.Fatal("expected no htlcs after migration") + } + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateInvoices, + false) +} + +// signDigestCompact generates a test signature to be used in the generation of +// test payment requests. +func signDigestCompact(hash []byte) ([]byte, error) { + // Should the signature reference a compressed public key or not. + isCompressedKey := true + + privKey, _ := btcec.PrivKeyFromBytes(btcec.S256(), testPrivKeyBytes) + + // btcec.SignCompact returns a pubkey-recoverable signature + sig, err := btcec.SignCompact( + btcec.S256(), privKey, hash, isCompressedKey, + ) + if err != nil { + return nil, fmt.Errorf("can't sign the hash: %v", err) + } + + return sig, nil +} + +// getPayReq creates a payment request for the given net. +func getPayReq(net *bitcoinCfg.Params) (string, error) { + options := []func(*zpay32.Invoice){ + zpay32.CLTVExpiry(uint64(testCltvDelta)), + zpay32.Description("test"), + } + + payReq, err := zpay32.NewInvoice( + net, [32]byte{}, time.Unix(1, 0), options..., + ) + if err != nil { + return "", err + } + return payReq.Encode( + zpay32.MessageSigner{ + SignCompact: signDigestCompact, + }, + ) +} diff --git a/channeldb/migrations.go b/channeldb/migrations.go index d875dc8d..3423a6d7 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -168,7 +168,7 @@ func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { invoiceBytesCopy = append(invoiceBytesCopy, padding...) invoiceReader := bytes.NewReader(invoiceBytesCopy) - invoice, err := deserializeInvoice(invoiceReader) + invoice, err := deserializeInvoiceLegacy(invoiceReader) if err != nil { return fmt.Errorf("unable to decode invoice: %v", err) } @@ -227,7 +227,7 @@ func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { // We've fully migrated an invoice, so we'll now update the // invoice in-place. var b bytes.Buffer - if err := serializeInvoice(&b, &invoice); err != nil { + if err := serializeInvoiceLegacy(&b, &invoice); err != nil { return err } diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index c79b60b8..055b564c 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -394,6 +394,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, Memo: []byte(invoice.Memo), Receipt: invoice.Receipt, PaymentRequest: []byte(payReqString), + FinalCltvDelta: int32(payReq.MinFinalCLTVExpiry()), + Expiry: payReq.Expiry(), Terms: channeldb.ContractTerm{ Value: amtMSat, PaymentPreimage: paymentPreimage,