htlcswitch/iterator: validate presence/omission of payload types

From BOLT 04:

The writer:
 - MUST include amt_to_forward and outgoing_cltv_value for every node.
 - MUST include short_channel_id for every non-final node.
 - MUST NOT include short_channel_id for the final node.
This commit is contained in:
Conner Fromknecht 2019-09-05 06:05:38 -07:00
parent aefec9b10f
commit 6015567927
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
2 changed files with 177 additions and 2 deletions

@ -2,6 +2,7 @@ package hop
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lightning-onion"
@ -10,6 +11,37 @@ import (
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
// ErrInvalidPayload is an error returned when a parsed onion payload either
// included or omitted incorrect records for a particular hop type.
type ErrInvalidPayload struct {
// Type the record's type that cause the violation.
Type tlv.Type
// Ommitted if true, signals that the sender did not include the record.
// Otherwise, the sender included the record when it shouldn't have.
Omitted bool
// FinalHop if true, indicates that the violation is for the final hop
// in the route (identified by next hop id), otherwise the violation is
// for an intermediate hop.
FinalHop bool
}
// Error returns a human-readable description of the invalid payload error.
func (e ErrInvalidPayload) Error() string {
hopType := "intermediate"
if e.FinalHop {
hopType = "final"
}
violation := "included"
if e.Omitted {
violation = "omitted"
}
return fmt.Sprintf("onion payload for %s hop %s record with type %d",
hopType, violation, e.Type)
}
// Payload encapsulates all information delivered to a hop in an onion payload. // Payload encapsulates all information delivered to a hop in an onion payload.
// A Hop can represent either a TLV or legacy payload. The primary forwarding // A Hop can represent either a TLV or legacy payload. The primary forwarding
// instruction can be accessed via ForwardingInfo, and additional records can be // instruction can be accessed via ForwardingInfo, and additional records can be
@ -53,7 +85,16 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return nil, err return nil, err
} }
_, err = tlvStream.DecodeWithParsedTypes(r) parsedTypes, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return nil, err
}
nextHop := lnwire.NewShortChanIDFromInt(cid)
// Validate whether the sender properly included or omitted tlv records
// in accordance with BOLT 04.
err = ValidateParsedPayloadTypes(parsedTypes, nextHop)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,7 +102,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
return &Payload{ return &Payload{
FwdInfo: ForwardingInfo{ FwdInfo: ForwardingInfo{
Network: BitcoinNetwork, Network: BitcoinNetwork,
NextHop: lnwire.NewShortChanIDFromInt(cid), NextHop: nextHop,
AmountToForward: lnwire.MilliSatoshi(amt), AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv, OutgoingCTLV: cltv,
}, },
@ -73,3 +114,48 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
func (h *Payload) ForwardingInfo() ForwardingInfo { func (h *Payload) ForwardingInfo() ForwardingInfo {
return h.FwdInfo return h.FwdInfo
} }
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to
// ensure that the proper fields are either included or omitted. The finalHop
// boolean should be true if the payload was parsed for an exit hop. The
// requirements for this method are described in BOLT 04.
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet,
nextHop lnwire.ShortChannelID) error {
isFinalHop := nextHop == Exit
_, hasAmt := parsedTypes[record.AmtOnionType]
_, hasLockTime := parsedTypes[record.LockTimeOnionType]
_, hasNextHop := parsedTypes[record.NextHopOnionType]
switch {
// All hops must include an amount to forward.
case !hasAmt:
return ErrInvalidPayload{
Type: record.AmtOnionType,
Omitted: true,
FinalHop: isFinalHop,
}
// All hops must include a cltv expiry.
case !hasLockTime:
return ErrInvalidPayload{
Type: record.LockTimeOnionType,
Omitted: true,
FinalHop: isFinalHop,
}
// The exit hop should omit the next hop id. If nextHop != Exit, the
// sender must have included a record, so we don't need to test for its
// inclusion at intermediate hops directly.
case isFinalHop && hasNextHop:
return ErrInvalidPayload{
Type: record.NextHopOnionType,
Omitted: false,
FinalHop: true,
}
}
return nil
}

@ -0,0 +1,89 @@
package hop_test
import (
"bytes"
"reflect"
"testing"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/record"
)
type decodePayloadTest struct {
name string
payload []byte
expErr error
}
var decodePayloadTests = []decodePayloadTest{
{
name: "final hop no amount",
payload: []byte{0x04, 0x00},
expErr: hop.ErrInvalidPayload{
Type: record.AmtOnionType,
Omitted: true,
FinalHop: true,
},
},
{
name: "intermediate hop no amount",
payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: record.AmtOnionType,
Omitted: true,
FinalHop: false,
},
},
{
name: "final hop no expiry",
payload: []byte{0x02, 0x00},
expErr: hop.ErrInvalidPayload{
Type: record.LockTimeOnionType,
Omitted: true,
FinalHop: true,
},
},
{
name: "intermediate hop no expiry",
payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: record.LockTimeOnionType,
Omitted: true,
FinalHop: false,
},
},
{
name: "final hop next sid present",
payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
},
expErr: hop.ErrInvalidPayload{
Type: record.NextHopOnionType,
Omitted: false,
FinalHop: true,
},
},
}
// TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the
// tests yields the expected errors depending on whether the proper fields were
// included or omitted.
func TestDecodeHopPayloadRecordValidation(t *testing.T) {
for _, test := range decodePayloadTests {
t.Run(test.name, func(t *testing.T) {
testDecodeHopPayloadValidation(t, test)
})
}
}
func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
_, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload))
if !reflect.DeepEqual(test.expErr, err) {
t.Fatalf("expected error mismatch, want: %v, got: %v",
test.expErr, err)
}
}