diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index 24eaf180..b38c5cad 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -7,9 +7,9 @@ import ( "time" "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" ) var ( @@ -212,7 +212,7 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) { key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4} htlc := HtlcAcceptDesc{ Amt: 500, - CustomRecords: make(hop.CustomRecordSet), + CustomRecords: make(record.CustomSet), } invoice, err := db.UpdateInvoice(paymentHash, func(invoice *Invoice) (*InvoiceUpdateDesc, error) { @@ -439,7 +439,7 @@ func TestDuplicateSettleInvoice(t *testing.T) { AcceptTime: time.Unix(1, 0), ResolveTime: time.Unix(1, 0), State: HtlcStateSettled, - CustomRecords: make(hop.CustomRecordSet), + CustomRecords: make(record.CustomSet), }, } @@ -751,7 +751,7 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { return nil, ErrInvoiceAlreadySettled } - noRecords := make(hop.CustomRecordSet) + noRecords := make(record.CustomSet) update := &InvoiceUpdateDesc{ State: &InvoiceStateUpdateDesc{ @@ -795,7 +795,7 @@ func TestCustomRecords(t *testing.T) { // Accept an htlc with custom records on this invoice. key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4} - records := hop.CustomRecordSet{ + records := record.CustomSet{ 100000: []byte{}, 100001: []byte{1, 2}, } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 106c0e41..8ded522f 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/tlv" ) @@ -317,7 +318,7 @@ type InvoiceHTLC struct { // CustomRecords contains the custom key/value pairs that accompanied // the htlc. - CustomRecords hop.CustomRecordSet + CustomRecords record.CustomSet } // HtlcAcceptDesc describes the details of a newly accepted htlc. @@ -337,7 +338,7 @@ type HtlcAcceptDesc struct { // CustomRecords contains the custom key/value pairs that accompanied // the htlc. - CustomRecords hop.CustomRecordSet + CustomRecords record.CustomSet } // InvoiceUpdateDesc describes the changes that should be applied to the diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 523ffbad..d8aafd9c 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -29,12 +29,6 @@ const ( RequiredViolation ) -const ( - // CustomTypeStart is the start of the custom tlv type range as defined - // in BOLT 01. - CustomTypeStart = 65536 -) - // String returns a human-readable description of the violation as a verb. func (v PayloadViolation) String() string { switch v { @@ -79,9 +73,6 @@ func (e ErrInvalidPayload) Error() string { hopType, e.Violation, e.Type) } -// CustomRecordSet stores a set of custom key/value pairs. -type CustomRecordSet map[uint64][]byte - // Payload encapsulates all information delivered to a hop in an onion payload. // A Hop can represent either a TLV or legacy payload. The primary forwarding // instruction can be accessed via ForwardingInfo, and additional records can be @@ -97,7 +88,7 @@ type Payload struct { // customRecords are user-defined records in the custom type range that // were included in the payload. - customRecords CustomRecordSet + customRecords record.CustomSet } // NewLegacyPayload builds a Payload from the amount, cltv, and next hop @@ -112,7 +103,7 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount), OutgoingCTLV: f.OutgoingCltv, }, - customRecords: make(CustomRecordSet), + customRecords: make(record.CustomSet), } } @@ -188,10 +179,10 @@ func (h *Payload) ForwardingInfo() ForwardingInfo { // NewCustomRecords filters the types parsed from the tlv stream for custom // records. -func NewCustomRecords(parsedTypes tlv.TypeMap) CustomRecordSet { - customRecords := make(CustomRecordSet) +func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet { + customRecords := make(record.CustomSet) for t, parseResult := range parsedTypes { - if parseResult == nil || t < CustomTypeStart { + if parseResult == nil || t < record.CustomTypeStart { continue } customRecords[uint64(t)] = parseResult @@ -261,7 +252,7 @@ func (h *Payload) MultiPath() *record.MPP { // CustomRecords returns the custom tlv type records that were parsed from the // payload. -func (h *Payload) CustomRecords() CustomRecordSet { +func (h *Payload) CustomRecords() record.CustomSet { return h.customRecords } @@ -280,7 +271,9 @@ func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type { // // We always accept custom fields, because a higher level // application may understand them. - if parseResult == nil || t%2 != 0 || t >= CustomTypeStart { + if parseResult == nil || t%2 != 0 || + t >= record.CustomTypeStart { + continue } diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index f8c0df21..b0a92534 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -244,7 +244,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { // Convert expected nil map to empty map, because we always expect an // initiated map from the payload. - expCustomRecords := make(hop.CustomRecordSet) + expCustomRecords := make(record.CustomSet) if test.expCustomRecords != nil { expCustomRecords = test.expCustomRecords } diff --git a/invoices/interface.go b/invoices/interface.go index bf7f0ed9..c511dded 100644 --- a/invoices/interface.go +++ b/invoices/interface.go @@ -1,7 +1,6 @@ package invoices import ( - "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/record" ) @@ -14,5 +13,5 @@ type Payload interface { // CustomRecords returns the custom tlv type records that were parsed // from the payload. - CustomRecords() hop.CustomRecordSet + CustomRecords() record.CustomSet } diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index 8b992258..eabb2ba2 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -7,7 +7,6 @@ import ( "time" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -690,8 +689,8 @@ func (p *mockPayload) MultiPath() *record.MPP { return p.mpp } -func (p *mockPayload) CustomRecords() hop.CustomRecordSet { - return make(hop.CustomRecordSet) +func (p *mockPayload) CustomRecords() record.CustomSet { + return make(record.CustomSet) } // TestSettleMpp tests settling of an invoice with multiple partial payments. diff --git a/invoices/update.go b/invoices/update.go index 3175efa6..913caeb0 100644 --- a/invoices/update.go +++ b/invoices/update.go @@ -3,8 +3,6 @@ package invoices import ( "errors" - "github.com/lightningnetwork/lnd/htlcswitch/hop" - "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -105,7 +103,7 @@ type invoiceUpdateCtx struct { expiry uint32 currentHeight int32 finalCltvRejectDelta int32 - customRecords hop.CustomRecordSet + customRecords record.CustomSet mpp *record.MPP } diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 38c4a1b7..b9de9617 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -13,7 +13,6 @@ import ( "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" @@ -385,9 +384,9 @@ func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record, // tlvRecords is sorted, so we only need to check that the first // element is within the custom range. - if uint64(tlvRecords[0].Type()) < hop.CustomTypeStart { + if uint64(tlvRecords[0].Type()) < record.CustomTypeStart { return nil, fmt.Errorf("no custom records with types "+ - "below %v allowed", hop.CustomTypeStart) + "below %v allowed", record.CustomTypeStart) } return tlvRecords, nil diff --git a/record/custom_records.go b/record/custom_records.go new file mode 100644 index 00000000..36e9e5ac --- /dev/null +++ b/record/custom_records.go @@ -0,0 +1,10 @@ +package record + +const ( + // CustomTypeStart is the start of the custom tlv type range as defined + // in BOLT 01. + CustomTypeStart = 65536 +) + +// CustomSet stores a set of custom key/value pairs. +type CustomSet map[uint64][]byte