diff --git a/record/mpp.go b/record/mpp.go index 38877caf..6e260d54 100644 --- a/record/mpp.go +++ b/record/mpp.go @@ -98,6 +98,11 @@ func (r *MPP) Record() tlv.Record { ) } +// PayloadSize returns the size this record takes up in encoded form. +func (r *MPP) PayloadSize() uint64 { + return 32 + tlv.SizeTUint64(uint64(r.totalMsat)) +} + // String returns a human-readable representation of the mpp payload field. func (r *MPP) String() string { return fmt.Sprintf("total=%v, addr=%x", r.totalMsat, r.paymentAddr) diff --git a/routing/route/route.go b/routing/route/route.go index ff043ef8..90a407e3 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -184,6 +184,53 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { return tlvStream.Encode(w) } +// Size returns the total size this hop's payload would take up in the onion +// packet. +func (h *Hop) PayloadSize(nextChanID uint64) uint64 { + if h.LegacyPayload { + return sphinx.LegacyHopDataSize + } + + var payloadSize uint64 + + addRecord := func(tlvType tlv.Type, length uint64) { + payloadSize += tlv.VarIntSize(uint64(tlvType)) + + tlv.VarIntSize(length) + length + } + + // Add amount size. + addRecord(record.AmtOnionType, tlv.SizeTUint64(uint64(h.AmtToForward))) + + // Add lock time size. + addRecord( + record.LockTimeOnionType, + tlv.SizeTUint64(uint64(h.OutgoingTimeLock)), + ) + + // Add next hop if present. + if nextChanID != 0 { + addRecord(record.NextHopOnionType, 8) + } + + // Add mpp if present. + if h.MPP != nil { + addRecord(record.MPPOnionType, h.MPP.PayloadSize()) + } + + // Add custom records. + for k, v := range h.CustomRecords { + addRecord(tlv.Type(k), uint64(len(v))) + } + + // Add the size required to encode the payload length. + payloadSize += tlv.VarIntSize(payloadSize) + + // Add HMAC. + payloadSize += sphinx.HMACSize + + return payloadSize +} + // Route represents a path through the channel graph which runs over one or // more channels in succession. This struct carries all the information // required to craft the Sphinx onion packet, and send the payment along the diff --git a/routing/route/route_test.go b/routing/route/route_test.go index 2894cbdc..6c32d8be 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -2,12 +2,20 @@ package route import ( "bytes" + "encoding/hex" "testing" + "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" ) +var ( + testPrivKeyBytes, _ = hex.DecodeString("e126f68f7eafcc8b74f54d269fe206be715000f94dac067d1c04a8ca3b2db734") + _, testPubKey = btcec.PrivKeyFromBytes(btcec.S256(), testPrivKeyBytes) + testPubKeyBytes, _ = NewVertexFromBytes(testPubKey.SerializeCompressed()) +) + // TestRouteTotalFees checks that a route reports the expected total fee. func TestRouteTotalFees(t *testing.T) { t.Parallel() @@ -56,7 +64,6 @@ func TestRouteTotalFees(t *testing.T) { if r.TotalFees() != fee { t.Fatalf("expected %v fees, got %v", fee, r.TotalFees()) } - } var ( @@ -93,3 +100,57 @@ func TestMPPHop(t *testing.T) { t.Fatalf("expected err: %v, got: %v", nil, err) } } + +// TestPayloadSize tests the payload size calculation that is provided by Hop +// structs. +func TestPayloadSize(t *testing.T) { + hops := []*Hop{ + { + PubKeyBytes: testPubKeyBytes, + AmtToForward: 1000, + OutgoingTimeLock: 600000, + ChannelID: 3432483437438, + LegacyPayload: true, + }, + { + PubKeyBytes: testPubKeyBytes, + AmtToForward: 1200, + OutgoingTimeLock: 700000, + ChannelID: 63584534844, + }, + { + PubKeyBytes: testPubKeyBytes, + AmtToForward: 1200, + OutgoingTimeLock: 700000, + MPP: record.NewMPP(500, [32]byte{}), + CustomRecords: map[uint64][]byte{ + 100000: {1, 2, 3}, + 1000000: {4, 5}, + }, + }, + } + + rt := Route{ + Hops: hops, + } + path, err := rt.ToSphinxPath() + if err != nil { + t.Fatal(err) + } + + for i, onionHop := range path[:path.TrueRouteLength()] { + hop := hops[i] + var nextChan uint64 + if i < len(hops)-1 { + nextChan = hops[i+1].ChannelID + } + + expected := uint64(onionHop.HopPayload.NumBytes()) + actual := hop.PayloadSize(nextChan) + if expected != actual { + t.Fatalf("unexpected payload size at hop %v: "+ + "expected %v, got %v", + i, expected, actual) + } + } +} diff --git a/tlv/varint.go b/tlv/varint.go index 3888bfcb..38c7a7cd 100644 --- a/tlv/varint.go +++ b/tlv/varint.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "errors" "io" + + "github.com/btcsuite/btcd/wire" ) // ErrVarIntNotCanonical signals that the decoded varint was not minimally encoded. @@ -107,3 +109,8 @@ func WriteVarInt(w io.Writer, val uint64, buf *[8]byte) error { _, err := w.Write(buf[:length]) return err } + +// VarIntSize returns the required number of bytes to encode a var int. +func VarIntSize(val uint64) uint64 { + return uint64(wire.VarIntSerializeSize(val)) +}