hop: store custom records from payload

This commit is contained in:
Joost Jager 2019-11-19 12:32:56 +01:00
parent cbe213fd0c
commit 37258c414c
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
2 changed files with 51 additions and 6 deletions

@ -79,6 +79,9 @@ func (e ErrInvalidPayload) Error() string {
hopType, e.Violation, e.Type)
}
// CustomRecordSet stores a set of custom key/value pairs.
type CustomRecordSet map[uint64][]byte
// Payload encapsulates all information delivered to a hop in an onion payload.
// A Hop can represent either a TLV or legacy payload. The primary forwarding
// instruction can be accessed via ForwardingInfo, and additional records can be
@ -91,6 +94,10 @@ type Payload struct {
// MPP holds the info provided in an option_mpp record when parsed from
// a TLV onion payload.
MPP *record.MPP
// customRecords are user-defined records in the custom type range that
// were included in the payload.
customRecords CustomRecordSet
}
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop
@ -105,6 +112,7 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount),
OutgoingCTLV: f.OutgoingCltv,
},
customRecords: make(CustomRecordSet),
}
}
@ -157,6 +165,9 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
mpp = nil
}
// Filter out the custom records.
customRecords := NewCustomRecords(parsedTypes)
return &Payload{
FwdInfo: ForwardingInfo{
Network: BitcoinNetwork,
@ -164,7 +175,8 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv,
},
MPP: mpp,
MPP: mpp,
customRecords: customRecords,
}, nil
}
@ -174,6 +186,19 @@ func (h *Payload) ForwardingInfo() ForwardingInfo {
return h.FwdInfo
}
// NewCustomRecords filters the types parsed from the tlv stream for custom
// records.
func NewCustomRecords(parsedTypes tlv.TypeMap) CustomRecordSet {
customRecords := make(CustomRecordSet)
for t, parseResult := range parsedTypes {
if parseResult == nil || t < CustomTypeStart {
continue
}
customRecords[uint64(t)] = parseResult
}
return customRecords
}
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
// ensure that the proper fields are either included or omitted. The finalHop
// boolean should be true if the payload was parsed for an exit hop. The
@ -234,6 +259,12 @@ func (h *Payload) MultiPath() *record.MPP {
return h.MPP
}
// CustomRecords returns the custom tlv type records that were parsed from the
// payload.
func (h *Payload) CustomRecords() CustomRecordSet {
return h.customRecords
}
// getMinRequiredViolation checks for unrecognized required (even) fields in the
// standard range and returns the lowest required type. Always returning the
// lowest required type allows a failure message to be deterministic.

@ -11,10 +11,11 @@ import (
)
type decodePayloadTest struct {
name string
payload []byte
expErr error
shouldHaveMPP bool
name string
payload []byte
expErr error
expCustomRecords map[uint64][]byte
shouldHaveMPP bool
}
var decodePayloadTests = []decodePayloadTest{
@ -133,7 +134,10 @@ var decodePayloadTests = []decodePayloadTest{
{
name: "required type in custom range",
payload: []byte{0x02, 0x00, 0x04, 0x00,
0xfe, 0x00, 0x01, 0x00, 0x00, 0x00,
0xfe, 0x00, 0x01, 0x00, 0x00, 0x02, 0x10, 0x11,
},
expCustomRecords: map[uint64][]byte{
65536: {0x10, 0x11},
},
},
{
@ -237,4 +241,14 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
} else if p.MPP != nil {
t.Fatalf("unexpected MPP payload")
}
// Convert expected nil map to empty map, because we always expect an
// initiated map from the payload.
expCustomRecords := make(hop.CustomRecordSet)
if test.expCustomRecords != nil {
expCustomRecords = test.expCustomRecords
}
if !reflect.DeepEqual(expCustomRecords, p.CustomRecords()) {
t.Fatalf("invalid custom records")
}
}