hop: store custom records from payload
This commit is contained in:
parent
cbe213fd0c
commit
37258c414c
@ -79,6 +79,9 @@ func (e ErrInvalidPayload) Error() string {
|
||||
hopType, e.Violation, e.Type)
|
||||
}
|
||||
|
||||
// CustomRecordSet stores a set of custom key/value pairs.
|
||||
type CustomRecordSet map[uint64][]byte
|
||||
|
||||
// Payload encapsulates all information delivered to a hop in an onion payload.
|
||||
// A Hop can represent either a TLV or legacy payload. The primary forwarding
|
||||
// instruction can be accessed via ForwardingInfo, and additional records can be
|
||||
@ -91,6 +94,10 @@ type Payload struct {
|
||||
// MPP holds the info provided in an option_mpp record when parsed from
|
||||
// a TLV onion payload.
|
||||
MPP *record.MPP
|
||||
|
||||
// customRecords are user-defined records in the custom type range that
|
||||
// were included in the payload.
|
||||
customRecords CustomRecordSet
|
||||
}
|
||||
|
||||
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
|
||||
@ -105,6 +112,7 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
|
||||
AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount),
|
||||
OutgoingCTLV: f.OutgoingCltv,
|
||||
},
|
||||
customRecords: make(CustomRecordSet),
|
||||
}
|
||||
}
|
||||
|
||||
@ -157,6 +165,9 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
mpp = nil
|
||||
}
|
||||
|
||||
// Filter out the custom records.
|
||||
customRecords := NewCustomRecords(parsedTypes)
|
||||
|
||||
return &Payload{
|
||||
FwdInfo: ForwardingInfo{
|
||||
Network: BitcoinNetwork,
|
||||
@ -164,7 +175,8 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
|
||||
AmountToForward: lnwire.MilliSatoshi(amt),
|
||||
OutgoingCTLV: cltv,
|
||||
},
|
||||
MPP: mpp,
|
||||
MPP: mpp,
|
||||
customRecords: customRecords,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -174,6 +186,19 @@ func (h *Payload) ForwardingInfo() ForwardingInfo {
|
||||
return h.FwdInfo
|
||||
}
|
||||
|
||||
// NewCustomRecords filters the types parsed from the tlv stream for custom
|
||||
// records.
|
||||
func NewCustomRecords(parsedTypes tlv.TypeMap) CustomRecordSet {
|
||||
customRecords := make(CustomRecordSet)
|
||||
for t, parseResult := range parsedTypes {
|
||||
if parseResult == nil || t < CustomTypeStart {
|
||||
continue
|
||||
}
|
||||
customRecords[uint64(t)] = parseResult
|
||||
}
|
||||
return customRecords
|
||||
}
|
||||
|
||||
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
|
||||
// ensure that the proper fields are either included or omitted. The finalHop
|
||||
// boolean should be true if the payload was parsed for an exit hop. The
|
||||
@ -234,6 +259,12 @@ func (h *Payload) MultiPath() *record.MPP {
|
||||
return h.MPP
|
||||
}
|
||||
|
||||
// CustomRecords returns the custom tlv type records that were parsed from the
|
||||
// payload.
|
||||
func (h *Payload) CustomRecords() CustomRecordSet {
|
||||
return h.customRecords
|
||||
}
|
||||
|
||||
// getMinRequiredViolation checks for unrecognized required (even) fields in the
|
||||
// standard range and returns the lowest required type. Always returning the
|
||||
// lowest required type allows a failure message to be deterministic.
|
||||
|
@ -11,10 +11,11 @@ import (
|
||||
)
|
||||
|
||||
type decodePayloadTest struct {
|
||||
name string
|
||||
payload []byte
|
||||
expErr error
|
||||
shouldHaveMPP bool
|
||||
name string
|
||||
payload []byte
|
||||
expErr error
|
||||
expCustomRecords map[uint64][]byte
|
||||
shouldHaveMPP bool
|
||||
}
|
||||
|
||||
var decodePayloadTests = []decodePayloadTest{
|
||||
@ -133,7 +134,10 @@ var decodePayloadTests = []decodePayloadTest{
|
||||
{
|
||||
name: "required type in custom range",
|
||||
payload: []byte{0x02, 0x00, 0x04, 0x00,
|
||||
0xfe, 0x00, 0x01, 0x00, 0x00, 0x00,
|
||||
0xfe, 0x00, 0x01, 0x00, 0x00, 0x02, 0x10, 0x11,
|
||||
},
|
||||
expCustomRecords: map[uint64][]byte{
|
||||
65536: {0x10, 0x11},
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -237,4 +241,14 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
|
||||
} else if p.MPP != nil {
|
||||
t.Fatalf("unexpected MPP payload")
|
||||
}
|
||||
|
||||
// Convert expected nil map to empty map, because we always expect an
|
||||
// initiated map from the payload.
|
||||
expCustomRecords := make(hop.CustomRecordSet)
|
||||
if test.expCustomRecords != nil {
|
||||
expCustomRecords = test.expCustomRecords
|
||||
}
|
||||
if !reflect.DeepEqual(expCustomRecords, p.CustomRecords()) {
|
||||
t.Fatalf("invalid custom records")
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user