record: move CustomRecordSet

This commit is contained in:
Joost Jager 2019-12-12 00:01:55 +01:00
parent 7aa4a7c7fc
commit 8b5bb0ac63
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
9 changed files with 34 additions and 35 deletions

@ -7,9 +7,9 @@ import (
"time" "time"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
) )
var ( var (
@ -212,7 +212,7 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) {
key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4} key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4}
htlc := HtlcAcceptDesc{ htlc := HtlcAcceptDesc{
Amt: 500, Amt: 500,
CustomRecords: make(hop.CustomRecordSet), CustomRecords: make(record.CustomSet),
} }
invoice, err := db.UpdateInvoice(paymentHash, invoice, err := db.UpdateInvoice(paymentHash,
func(invoice *Invoice) (*InvoiceUpdateDesc, error) { func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
@ -439,7 +439,7 @@ func TestDuplicateSettleInvoice(t *testing.T) {
AcceptTime: time.Unix(1, 0), AcceptTime: time.Unix(1, 0),
ResolveTime: time.Unix(1, 0), ResolveTime: time.Unix(1, 0),
State: HtlcStateSettled, State: HtlcStateSettled,
CustomRecords: make(hop.CustomRecordSet), CustomRecords: make(record.CustomSet),
}, },
} }
@ -751,7 +751,7 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
return nil, ErrInvoiceAlreadySettled return nil, ErrInvoiceAlreadySettled
} }
noRecords := make(hop.CustomRecordSet) noRecords := make(record.CustomSet)
update := &InvoiceUpdateDesc{ update := &InvoiceUpdateDesc{
State: &InvoiceStateUpdateDesc{ State: &InvoiceStateUpdateDesc{
@ -795,7 +795,7 @@ func TestCustomRecords(t *testing.T) {
// Accept an htlc with custom records on this invoice. // Accept an htlc with custom records on this invoice.
key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4} key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4}
records := hop.CustomRecordSet{ records := record.CustomSet{
100000: []byte{}, 100000: []byte{},
100001: []byte{1, 2}, 100001: []byte{1, 2},
} }

@ -12,6 +12,7 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
@ -317,7 +318,7 @@ type InvoiceHTLC struct {
// CustomRecords contains the custom key/value pairs that accompanied // CustomRecords contains the custom key/value pairs that accompanied
// the htlc. // the htlc.
CustomRecords hop.CustomRecordSet CustomRecords record.CustomSet
} }
// HtlcAcceptDesc describes the details of a newly accepted htlc. // HtlcAcceptDesc describes the details of a newly accepted htlc.
@ -337,7 +338,7 @@ type HtlcAcceptDesc struct {
// CustomRecords contains the custom key/value pairs that accompanied // CustomRecords contains the custom key/value pairs that accompanied
// the htlc. // the htlc.
CustomRecords hop.CustomRecordSet CustomRecords record.CustomSet
} }
// InvoiceUpdateDesc describes the changes that should be applied to the // InvoiceUpdateDesc describes the changes that should be applied to the

@ -29,12 +29,6 @@ const (
RequiredViolation RequiredViolation
) )
const (
// CustomTypeStart is the start of the custom tlv type range as defined
// in BOLT 01.
CustomTypeStart = 65536
)
// String returns a human-readable description of the violation as a verb. // String returns a human-readable description of the violation as a verb.
func (v PayloadViolation) String() string { func (v PayloadViolation) String() string {
switch v { switch v {
@ -79,9 +73,6 @@ func (e ErrInvalidPayload) Error() string {
hopType, e.Violation, e.Type) hopType, e.Violation, e.Type)
} }
// CustomRecordSet stores a set of custom key/value pairs.
type CustomRecordSet map[uint64][]byte
// Payload encapsulates all information delivered to a hop in an onion payload. // Payload encapsulates all information delivered to a hop in an onion payload.
// A Hop can represent either a TLV or legacy payload. The primary forwarding // A Hop can represent either a TLV or legacy payload. The primary forwarding
// instruction can be accessed via ForwardingInfo, and additional records can be // instruction can be accessed via ForwardingInfo, and additional records can be
@ -97,7 +88,7 @@ type Payload struct {
// customRecords are user-defined records in the custom type range that // customRecords are user-defined records in the custom type range that
// were included in the payload. // were included in the payload.
customRecords CustomRecordSet customRecords record.CustomSet
} }
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop // NewLegacyPayload builds a Payload from the amount, cltv, and next hop
@ -112,7 +103,7 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount), AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount),
OutgoingCTLV: f.OutgoingCltv, OutgoingCTLV: f.OutgoingCltv,
}, },
customRecords: make(CustomRecordSet), customRecords: make(record.CustomSet),
} }
} }
@ -188,10 +179,10 @@ func (h *Payload) ForwardingInfo() ForwardingInfo {
// NewCustomRecords filters the types parsed from the tlv stream for custom // NewCustomRecords filters the types parsed from the tlv stream for custom
// records. // records.
func NewCustomRecords(parsedTypes tlv.TypeMap) CustomRecordSet { func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet {
customRecords := make(CustomRecordSet) customRecords := make(record.CustomSet)
for t, parseResult := range parsedTypes { for t, parseResult := range parsedTypes {
if parseResult == nil || t < CustomTypeStart { if parseResult == nil || t < record.CustomTypeStart {
continue continue
} }
customRecords[uint64(t)] = parseResult customRecords[uint64(t)] = parseResult
@ -261,7 +252,7 @@ func (h *Payload) MultiPath() *record.MPP {
// CustomRecords returns the custom tlv type records that were parsed from the // CustomRecords returns the custom tlv type records that were parsed from the
// payload. // payload.
func (h *Payload) CustomRecords() CustomRecordSet { func (h *Payload) CustomRecords() record.CustomSet {
return h.customRecords return h.customRecords
} }
@ -280,7 +271,9 @@ func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type {
// //
// We always accept custom fields, because a higher level // We always accept custom fields, because a higher level
// application may understand them. // application may understand them.
if parseResult == nil || t%2 != 0 || t >= CustomTypeStart { if parseResult == nil || t%2 != 0 ||
t >= record.CustomTypeStart {
continue continue
} }

@ -244,7 +244,7 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
// Convert expected nil map to empty map, because we always expect an // Convert expected nil map to empty map, because we always expect an
// initiated map from the payload. // initiated map from the payload.
expCustomRecords := make(hop.CustomRecordSet) expCustomRecords := make(record.CustomSet)
if test.expCustomRecords != nil { if test.expCustomRecords != nil {
expCustomRecords = test.expCustomRecords expCustomRecords = test.expCustomRecords
} }

@ -1,7 +1,6 @@
package invoices package invoices
import ( import (
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
) )
@ -14,5 +13,5 @@ type Payload interface {
// CustomRecords returns the custom tlv type records that were parsed // CustomRecords returns the custom tlv type records that were parsed
// from the payload. // from the payload.
CustomRecords() hop.CustomRecordSet CustomRecords() record.CustomSet
} }

@ -7,7 +7,6 @@ import (
"time" "time"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
@ -690,8 +689,8 @@ func (p *mockPayload) MultiPath() *record.MPP {
return p.mpp return p.mpp
} }
func (p *mockPayload) CustomRecords() hop.CustomRecordSet { func (p *mockPayload) CustomRecords() record.CustomSet {
return make(hop.CustomRecordSet) return make(record.CustomSet)
} }
// TestSettleMpp tests settling of an invoice with multiple partial payments. // TestSettleMpp tests settling of an invoice with multiple partial payments.

@ -3,8 +3,6 @@ package invoices
import ( import (
"errors" "errors"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
@ -105,7 +103,7 @@ type invoiceUpdateCtx struct {
expiry uint32 expiry uint32
currentHeight int32 currentHeight int32
finalCltvRejectDelta int32 finalCltvRejectDelta int32
customRecords hop.CustomRecordSet customRecords record.CustomSet
mpp *record.MPP mpp *record.MPP
} }

@ -13,7 +13,6 @@ import (
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -385,9 +384,9 @@ func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record,
// tlvRecords is sorted, so we only need to check that the first // tlvRecords is sorted, so we only need to check that the first
// element is within the custom range. // element is within the custom range.
if uint64(tlvRecords[0].Type()) < hop.CustomTypeStart { if uint64(tlvRecords[0].Type()) < record.CustomTypeStart {
return nil, fmt.Errorf("no custom records with types "+ return nil, fmt.Errorf("no custom records with types "+
"below %v allowed", hop.CustomTypeStart) "below %v allowed", record.CustomTypeStart)
} }
return tlvRecords, nil return tlvRecords, nil

10
record/custom_records.go Normal file

@ -0,0 +1,10 @@
package record
const (
// CustomTypeStart is the start of the custom tlv type range as defined
// in BOLT 01.
CustomTypeStart = 65536
)
// CustomSet stores a set of custom key/value pairs.
type CustomSet map[uint64][]byte