From 37258c414c58a53ff72f2433d4f77d786d269791 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 19 Nov 2019 12:32:56 +0100 Subject: [PATCH] hop: store custom records from payload --- htlcswitch/hop/payload.go | 33 ++++++++++++++++++++++++++++++++- htlcswitch/hop/payload_test.go | 24 +++++++++++++++++++----- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index afc31308..523ffbad 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -79,6 +79,9 @@ func (e ErrInvalidPayload) Error() string { hopType, e.Violation, e.Type) } +// CustomRecordSet stores a set of custom key/value pairs. +type CustomRecordSet map[uint64][]byte + // Payload encapsulates all information delivered to a hop in an onion payload. // A Hop can represent either a TLV or legacy payload. The primary forwarding // instruction can be accessed via ForwardingInfo, and additional records can be @@ -91,6 +94,10 @@ type Payload struct { // MPP holds the info provided in an option_mpp record when parsed from // a TLV onion payload. MPP *record.MPP + + // customRecords are user-defined records in the custom type range that + // were included in the payload. + customRecords CustomRecordSet } // NewLegacyPayload builds a Payload from the amount, cltv, and next hop @@ -105,6 +112,7 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount), OutgoingCTLV: f.OutgoingCltv, }, + customRecords: make(CustomRecordSet), } } @@ -157,6 +165,9 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { mpp = nil } + // Filter out the custom records. + customRecords := NewCustomRecords(parsedTypes) + return &Payload{ FwdInfo: ForwardingInfo{ Network: BitcoinNetwork, @@ -164,7 +175,8 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { AmountToForward: lnwire.MilliSatoshi(amt), OutgoingCTLV: cltv, }, - MPP: mpp, + MPP: mpp, + customRecords: customRecords, }, nil } @@ -174,6 +186,19 @@ func (h *Payload) ForwardingInfo() ForwardingInfo { return h.FwdInfo } +// NewCustomRecords filters the types parsed from the tlv stream for custom +// records. +func NewCustomRecords(parsedTypes tlv.TypeMap) CustomRecordSet { + customRecords := make(CustomRecordSet) + for t, parseResult := range parsedTypes { + if parseResult == nil || t < CustomTypeStart { + continue + } + customRecords[uint64(t)] = parseResult + } + return customRecords +} + // ValidateParsedPayloadTypes checks the types parsed from a hop payload to // ensure that the proper fields are either included or omitted. The finalHop // boolean should be true if the payload was parsed for an exit hop. The @@ -234,6 +259,12 @@ func (h *Payload) MultiPath() *record.MPP { return h.MPP } +// CustomRecords returns the custom tlv type records that were parsed from the +// payload. +func (h *Payload) CustomRecords() CustomRecordSet { + return h.customRecords +} + // getMinRequiredViolation checks for unrecognized required (even) fields in the // standard range and returns the lowest required type. Always returning the // lowest required type allows a failure message to be deterministic. diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 4092ff48..f8c0df21 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -11,10 +11,11 @@ import ( ) type decodePayloadTest struct { - name string - payload []byte - expErr error - shouldHaveMPP bool + name string + payload []byte + expErr error + expCustomRecords map[uint64][]byte + shouldHaveMPP bool } var decodePayloadTests = []decodePayloadTest{ @@ -133,7 +134,10 @@ var decodePayloadTests = []decodePayloadTest{ { name: "required type in custom range", payload: []byte{0x02, 0x00, 0x04, 0x00, - 0xfe, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xfe, 0x00, 0x01, 0x00, 0x00, 0x02, 0x10, 0x11, + }, + expCustomRecords: map[uint64][]byte{ + 65536: {0x10, 0x11}, }, }, { @@ -237,4 +241,14 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { } else if p.MPP != nil { t.Fatalf("unexpected MPP payload") } + + // Convert expected nil map to empty map, because we always expect an + // initiated map from the payload. + expCustomRecords := make(hop.CustomRecordSet) + if test.expCustomRecords != nil { + expCustomRecords = test.expCustomRecords + } + if !reflect.DeepEqual(expCustomRecords, p.CustomRecords()) { + t.Fatalf("invalid custom records") + } }