htlcswitch: return hop.Payload from HopIterator

This commit is contained in:
Conner Fromknecht 2019-11-04 15:10:15 -08:00
parent 4a6f5d8d3d
commit 70708e2e71
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
6 changed files with 66 additions and 57 deletions

@ -16,12 +16,13 @@ import (
// interpret the forwarding information encoded within the HTLC packet, and hop // interpret the forwarding information encoded within the HTLC packet, and hop
// to encode the forwarding information for the _next_ hop. // to encode the forwarding information for the _next_ hop.
type Iterator interface { type Iterator interface {
// ForwardingInstructions returns the set of fields that detail exactly // HopPayload returns the set of fields that detail exactly _how_ this
// _how_ this hop should forward the HTLC to the next hop. // hop should forward the HTLC to the next hop. Additionally, the
// Additionally, the information encoded within the returned // information encoded within the returned ForwardingInfo is to be used
// ForwardingInfo is to be used by each hop to authenticate the // by each hop to authenticate the information given to it by the prior
// information given to it by the prior hop. // hop. The payload will also contain any additional TLV fields provided
ForwardingInstructions() (ForwardingInfo, error) // by the sender.
HopPayload() (*Payload, error)
// ExtraOnionBlob returns the additional EOB data (if available). // ExtraOnionBlob returns the additional EOB data (if available).
ExtraOnionBlob() []byte ExtraOnionBlob() []byte
@ -72,37 +73,31 @@ func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error {
return r.processedPacket.NextPacket.Encode(w) return r.processedPacket.NextPacket.Encode(w)
} }
// ForwardingInstructions returns the set of fields that detail exactly _how_ // HopPayload returns the set of fields that detail exactly _how_ this hop
// this hop should forward the HTLC to the next hop. Additionally, the // should forward the HTLC to the next hop. Additionally, the information
// information encoded within the returned ForwardingInfo is to be used by each // encoded within the returned ForwardingInfo is to be used by each hop to
// hop to authenticate the information given to it by the prior hop. // authenticate the information given to it by the prior hop. The payload will
// also contain any additional TLV fields provided by the sender.
// //
// NOTE: Part of the HopIterator interface. // NOTE: Part of the HopIterator interface.
func (r *sphinxHopIterator) ForwardingInstructions() (ForwardingInfo, error) { func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
switch r.processedPacket.Payload.Type { switch r.processedPacket.Payload.Type {
// If this is the legacy payload, then we'll extract the information // If this is the legacy payload, then we'll extract the information
// directly from the pre-populated ForwardingInstructions field. // directly from the pre-populated ForwardingInstructions field.
case sphinx.PayloadLegacy: case sphinx.PayloadLegacy:
fwdInst := r.processedPacket.ForwardingInstructions fwdInst := r.processedPacket.ForwardingInstructions
p := NewLegacyPayload(fwdInst) return NewLegacyPayload(fwdInst), nil
return p.ForwardingInfo(), nil
// Otherwise, if this is the TLV payload, then we'll make a new stream // Otherwise, if this is the TLV payload, then we'll make a new stream
// to decode only what we need to make routing decisions. // to decode only what we need to make routing decisions.
case sphinx.PayloadTLV: case sphinx.PayloadTLV:
p, err := NewPayloadFromReader(bytes.NewReader( return NewPayloadFromReader(bytes.NewReader(
r.processedPacket.Payload.Payload, r.processedPacket.Payload.Payload,
)) ))
if err != nil {
return ForwardingInfo{}, err
}
return p.ForwardingInfo(), nil
default: default:
return ForwardingInfo{}, fmt.Errorf("unknown "+ return nil, fmt.Errorf("unknown sphinx payload type: %v",
"sphinx payload type: %v",
r.processedPacket.Payload.Type) r.processedPacket.Payload.Type)
} }
} }

@ -85,12 +85,13 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) {
for i, testCase := range testCases { for i, testCase := range testCases {
iterator.processedPacket = testCase.sphinxPacket iterator.processedPacket = testCase.sphinxPacket
fwdInfo, err := iterator.ForwardingInstructions() pld, err := iterator.HopPayload()
if err != nil { if err != nil {
t.Fatalf("#%v: unable to extract forwarding "+ t.Fatalf("#%v: unable to extract forwarding "+
"instructions: %v", i, err) "instructions: %v", i, err)
} }
fwdInfo := pld.ForwardingInfo()
if fwdInfo != testCase.expectedFwdInfo { if fwdInfo != testCase.expectedFwdInfo {
t.Fatalf("#%v: wrong fwding info: expected %v, got %v", t.Fatalf("#%v: wrong fwding info: expected %v, got %v",
i, spew.Sdump(testCase.expectedFwdInfo), i, spew.Sdump(testCase.expectedFwdInfo),

@ -2642,7 +2642,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
heightNow := l.cfg.Switch.BestHeight() heightNow := l.cfg.Switch.BestHeight()
fwdInfo, err := chanIterator.ForwardingInstructions() pld, err := chanIterator.HopPayload()
if err != nil { if err != nil {
// If we're unable to process the onion payload, or we // If we're unable to process the onion payload, or we
// received invalid onion payload failure, then we // received invalid onion payload failure, then we
@ -2671,6 +2671,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
continue continue
} }
fwdInfo := pld.ForwardingInfo()
switch fwdInfo.NextHop { switch fwdInfo.NextHop {
case hop.Exit: case hop.Exit:
updated, err := l.processExitHop( updated, err := l.processExitHop(

@ -22,6 +22,7 @@ import (
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/build"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
@ -563,7 +564,7 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) {
// per-hop payload for outgoing time lock to be the incorrect value. // per-hop payload for outgoing time lock to be the incorrect value.
// The proper value of the outgoing CLTV should be the policy set by // The proper value of the outgoing CLTV should be the policy set by
// the receiving node, instead we set it to be a random value. // the receiving node, instead we set it to be a random value.
hops[0].OutgoingCTLV = 500 hops[0].FwdInfo.OutgoingCTLV = 500
firstHop := n.firstBobChannelLink.ShortChanID() firstHop := n.firstBobChannelLink.ShortChanID()
_, err = makePayment( _, err = makePayment(
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt,
@ -616,7 +617,7 @@ func TestExitNodeAmountPayloadMismatch(t *testing.T) {
// per-hop payload for amount to be the incorrect value. The proper // per-hop payload for amount to be the incorrect value. The proper
// value of the amount to forward should be the amount that the // value of the amount to forward should be the amount that the
// receiving node expects to receive. // receiving node expects to receive.
hops[0].AmountToForward = 1 hops[0].FwdInfo.AmountToForward = 1
firstHop := n.firstBobChannelLink.ShortChanID() firstHop := n.firstBobChannelLink.ShortChanID()
_, err = makePayment( _, err = makePayment(
n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt,
@ -4354,13 +4355,13 @@ func generateHtlcAndInvoice(t *testing.T,
htlcAmt := lnwire.NewMSatFromSatoshis(10000) htlcAmt := lnwire.NewMSatFromSatoshis(10000)
htlcExpiry := testStartingHeight + testInvoiceCltvExpiry htlcExpiry := testStartingHeight + testInvoiceCltvExpiry
hops := []hop.ForwardingInfo{ hops := []*hop.Payload{
{ hop.NewLegacyPayload(&sphinx.HopData{
Network: hop.BitcoinNetwork, Realm: [1]byte{}, // hop.BitcoinNetwork
NextHop: hop.Exit, NextAddress: [8]byte{}, // hop.Exit,
AmountToForward: htlcAmt, ForwardAmount: uint64(htlcAmt),
OutgoingCTLV: uint32(htlcExpiry), OutgoingCltv: uint32(htlcExpiry),
}, }),
} }
blob, err := generateRoute(hops...) blob, err := generateRoute(hops...)
if err != nil { if err != nil {

@ -265,16 +265,14 @@ func (s *mockServer) QuitSignal() <-chan struct{} {
// mockHopIterator represents the test version of hop iterator which instead // mockHopIterator represents the test version of hop iterator which instead
// of encrypting the path in onion blob just stores the path as a list of hops. // of encrypting the path in onion blob just stores the path as a list of hops.
type mockHopIterator struct { type mockHopIterator struct {
hops []hop.ForwardingInfo hops []*hop.Payload
} }
func newMockHopIterator(hops ...hop.ForwardingInfo) hop.Iterator { func newMockHopIterator(hops ...*hop.Payload) hop.Iterator {
return &mockHopIterator{hops: hops} return &mockHopIterator{hops: hops}
} }
func (r *mockHopIterator) ForwardingInstructions() ( func (r *mockHopIterator) HopPayload() (*hop.Payload, error) {
hop.ForwardingInfo, error) {
h := r.hops[0] h := r.hops[0]
r.hops = r.hops[1:] r.hops = r.hops[1:]
return h, nil return h, nil
@ -300,7 +298,8 @@ func (r *mockHopIterator) EncodeNextHop(w io.Writer) error {
} }
for _, hop := range r.hops { for _, hop := range r.hops {
if err := encodeFwdInfo(w, &hop); err != nil { fwdInfo := hop.ForwardingInfo()
if err := encodeFwdInfo(w, &fwdInfo); err != nil {
return err return err
} }
} }
@ -434,14 +433,22 @@ func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte,
} }
hopLength := binary.BigEndian.Uint32(b[:]) hopLength := binary.BigEndian.Uint32(b[:])
hops := make([]hop.ForwardingInfo, hopLength) hops := make([]*hop.Payload, hopLength)
for i := uint32(0); i < hopLength; i++ { for i := uint32(0); i < hopLength; i++ {
f := &hop.ForwardingInfo{} var f hop.ForwardingInfo
if err := decodeFwdInfo(r, f); err != nil { if err := decodeFwdInfo(r, &f); err != nil {
return nil, lnwire.CodeTemporaryChannelFailure return nil, lnwire.CodeTemporaryChannelFailure
} }
hops[i] = *f var nextHopBytes [8]byte
binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64())
hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
Realm: [1]byte{}, // hop.BitcoinNetwork
NextAddress: nextHopBytes,
ForwardAmount: uint64(f.AmountToForward),
OutgoingCltv: f.OutgoingCTLV,
})
} }
return newMockHopIterator(hops...), lnwire.CodeNone return newMockHopIterator(hops...), lnwire.CodeNone

@ -23,6 +23,7 @@ import (
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/go-errors/errors" "github.com/go-errors/errors"
sphinx "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
@ -601,7 +602,7 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32,
} }
// generateRoute generates the path blob by given array of peers. // generateRoute generates the path blob by given array of peers.
func generateRoute(hops ...hop.ForwardingInfo) ( func generateRoute(hops ...*hop.Payload) (
[lnwire.OnionPacketSize]byte, error) { [lnwire.OnionPacketSize]byte, error) {
var blob [lnwire.OnionPacketSize]byte var blob [lnwire.OnionPacketSize]byte
@ -642,13 +643,12 @@ type threeHopNetwork struct {
// also the time lock value needed to route an HTLC with the target amount over // also the time lock value needed to route an HTLC with the target amount over
// the specified path. // the specified path.
func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32, func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
path ...*channelLink) (lnwire.MilliSatoshi, uint32, path ...*channelLink) (lnwire.MilliSatoshi, uint32, []*hop.Payload) {
[]hop.ForwardingInfo) {
totalTimelock := startingHeight totalTimelock := startingHeight
runningAmt := payAmt runningAmt := payAmt
hops := make([]hop.ForwardingInfo, len(path)) hops := make([]*hop.Payload, len(path))
for i := len(path) - 1; i >= 0; i-- { for i := len(path) - 1; i >= 0; i-- {
// If this is the last hop, then the next hop is the special // If this is the last hop, then the next hop is the special
// "exit node". Otherwise, we look to the "prior" hop. // "exit node". Otherwise, we look to the "prior" hop.
@ -676,7 +676,7 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
amount := payAmt amount := payAmt
if i != len(path)-1 { if i != len(path)-1 {
prevHop := hops[i+1] prevHop := hops[i+1]
prevAmount := prevHop.AmountToForward prevAmount := prevHop.ForwardingInfo().AmountToForward
fee := ExpectedFee(path[i].cfg.FwrdingPolicy, prevAmount) fee := ExpectedFee(path[i].cfg.FwrdingPolicy, prevAmount)
runningAmt += fee runningAmt += fee
@ -687,12 +687,15 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32,
amount = runningAmt - fee amount = runningAmt - fee
} }
hops[i] = hop.ForwardingInfo{ var nextHopBytes [8]byte
Network: hop.BitcoinNetwork, binary.BigEndian.PutUint64(nextHopBytes[:], nextHop.ToUint64())
NextHop: nextHop,
AmountToForward: amount, hops[i] = hop.NewLegacyPayload(&sphinx.HopData{
OutgoingCTLV: timeLock, Realm: [1]byte{}, // hop.BitcoinNetwork
} NextAddress: nextHopBytes,
ForwardAmount: uint64(amount),
OutgoingCltv: timeLock,
})
} }
return runningAmt, totalTimelock, hops return runningAmt, totalTimelock, hops
@ -739,7 +742,7 @@ func waitForPayFuncResult(payFunc func() error, d time.Duration) error {
// * from Alice to Carol through the Bob // * from Alice to Carol through the Bob
// * from Alice to some another peer through the Bob // * from Alice to some another peer through the Bob
func makePayment(sendingPeer, receivingPeer lnpeer.Peer, func makePayment(sendingPeer, receivingPeer lnpeer.Peer,
firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo, firstHop lnwire.ShortChannelID, hops []*hop.Payload,
invoiceAmt, htlcAmt lnwire.MilliSatoshi, invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32) *paymentResponse { timelock uint32) *paymentResponse {
@ -773,7 +776,7 @@ func makePayment(sendingPeer, receivingPeer lnpeer.Peer,
// preparePayment creates an invoice at the receivingPeer and returns a function // preparePayment creates an invoice at the receivingPeer and returns a function
// that, when called, launches the payment from the sendingPeer. // that, when called, launches the payment from the sendingPeer.
func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, func preparePayment(sendingPeer, receivingPeer lnpeer.Peer,
firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo, firstHop lnwire.ShortChannelID, hops []*hop.Payload,
invoiceAmt, htlcAmt lnwire.MilliSatoshi, invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32) (*channeldb.Invoice, func() error, error) { timelock uint32) (*channeldb.Invoice, func() error, error) {
@ -1265,7 +1268,7 @@ func (n *twoHopNetwork) stop() {
} }
func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer,
firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo, firstHop lnwire.ShortChannelID, hops []*hop.Payload,
invoiceAmt, htlcAmt lnwire.MilliSatoshi, invoiceAmt, htlcAmt lnwire.MilliSatoshi,
timelock uint32, preimage lntypes.Preimage) chan error { timelock uint32, preimage lntypes.Preimage) chan error {