From 70708e2e71879394c445cb7696df577d5e96ea34 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Mon, 4 Nov 2019 15:10:15 -0800 Subject: [PATCH] htlcswitch: return hop.Payload from HopIterator --- htlcswitch/hop/iterator.go | 39 ++++++++++++++------------------- htlcswitch/hop/iterator_test.go | 3 ++- htlcswitch/link.go | 4 +++- htlcswitch/link_test.go | 19 ++++++++-------- htlcswitch/mock.go | 27 ++++++++++++++--------- htlcswitch/test_utils.go | 31 ++++++++++++++------------ 6 files changed, 66 insertions(+), 57 deletions(-) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index 062ca879..bf9f7b48 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -16,12 +16,13 @@ import ( // interpret the forwarding information encoded within the HTLC packet, and hop // to encode the forwarding information for the _next_ hop. type Iterator interface { - // ForwardingInstructions returns the set of fields that detail exactly - // _how_ this hop should forward the HTLC to the next hop. - // Additionally, the information encoded within the returned - // ForwardingInfo is to be used by each hop to authenticate the - // information given to it by the prior hop. - ForwardingInstructions() (ForwardingInfo, error) + // HopPayload returns the set of fields that detail exactly _how_ this + // hop should forward the HTLC to the next hop. Additionally, the + // information encoded within the returned ForwardingInfo is to be used + // by each hop to authenticate the information given to it by the prior + // hop. The payload will also contain any additional TLV fields provided + // by the sender. + HopPayload() (*Payload, error) // ExtraOnionBlob returns the additional EOB data (if available). ExtraOnionBlob() []byte @@ -72,37 +73,31 @@ func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error { return r.processedPacket.NextPacket.Encode(w) } -// ForwardingInstructions returns the set of fields that detail exactly _how_ -// this hop should forward the HTLC to the next hop. Additionally, the -// information encoded within the returned ForwardingInfo is to be used by each -// hop to authenticate the information given to it by the prior hop. +// HopPayload returns the set of fields that detail exactly _how_ this hop +// should forward the HTLC to the next hop. Additionally, the information +// encoded within the returned ForwardingInfo is to be used by each hop to +// authenticate the information given to it by the prior hop. The payload will +// also contain any additional TLV fields provided by the sender. // // NOTE: Part of the HopIterator interface. -func (r *sphinxHopIterator) ForwardingInstructions() (ForwardingInfo, error) { +func (r *sphinxHopIterator) HopPayload() (*Payload, error) { switch r.processedPacket.Payload.Type { + // If this is the legacy payload, then we'll extract the information // directly from the pre-populated ForwardingInstructions field. case sphinx.PayloadLegacy: fwdInst := r.processedPacket.ForwardingInstructions - p := NewLegacyPayload(fwdInst) - - return p.ForwardingInfo(), nil + return NewLegacyPayload(fwdInst), nil // Otherwise, if this is the TLV payload, then we'll make a new stream // to decode only what we need to make routing decisions. case sphinx.PayloadTLV: - p, err := NewPayloadFromReader(bytes.NewReader( + return NewPayloadFromReader(bytes.NewReader( r.processedPacket.Payload.Payload, )) - if err != nil { - return ForwardingInfo{}, err - } - - return p.ForwardingInfo(), nil default: - return ForwardingInfo{}, fmt.Errorf("unknown "+ - "sphinx payload type: %v", + return nil, fmt.Errorf("unknown sphinx payload type: %v", r.processedPacket.Payload.Type) } } diff --git a/htlcswitch/hop/iterator_test.go b/htlcswitch/hop/iterator_test.go index 822e8794..20c5632b 100644 --- a/htlcswitch/hop/iterator_test.go +++ b/htlcswitch/hop/iterator_test.go @@ -85,12 +85,13 @@ func TestSphinxHopIteratorForwardingInstructions(t *testing.T) { for i, testCase := range testCases { iterator.processedPacket = testCase.sphinxPacket - fwdInfo, err := iterator.ForwardingInstructions() + pld, err := iterator.HopPayload() if err != nil { t.Fatalf("#%v: unable to extract forwarding "+ "instructions: %v", i, err) } + fwdInfo := pld.ForwardingInfo() if fwdInfo != testCase.expectedFwdInfo { t.Fatalf("#%v: wrong fwding info: expected %v, got %v", i, spew.Sdump(testCase.expectedFwdInfo), diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 7209789d..4593d042 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2642,7 +2642,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, heightNow := l.cfg.Switch.BestHeight() - fwdInfo, err := chanIterator.ForwardingInstructions() + pld, err := chanIterator.HopPayload() if err != nil { // If we're unable to process the onion payload, or we // received invalid onion payload failure, then we @@ -2671,6 +2671,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, continue } + fwdInfo := pld.ForwardingInfo() + switch fwdInfo.NextHop { case hop.Exit: updated, err := l.processExitHop( diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 13dd7869..93eb3763 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -22,6 +22,7 @@ import ( "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" @@ -563,7 +564,7 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { // per-hop payload for outgoing time lock to be the incorrect value. // The proper value of the outgoing CLTV should be the policy set by // the receiving node, instead we set it to be a random value. - hops[0].OutgoingCTLV = 500 + hops[0].FwdInfo.OutgoingCTLV = 500 firstHop := n.firstBobChannelLink.ShortChanID() _, err = makePayment( n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, @@ -616,7 +617,7 @@ func TestExitNodeAmountPayloadMismatch(t *testing.T) { // per-hop payload for amount to be the incorrect value. The proper // value of the amount to forward should be the amount that the // receiving node expects to receive. - hops[0].AmountToForward = 1 + hops[0].FwdInfo.AmountToForward = 1 firstHop := n.firstBobChannelLink.ShortChanID() _, err = makePayment( n.aliceServer, n.bobServer, firstHop, hops, amount, htlcAmt, @@ -4354,13 +4355,13 @@ func generateHtlcAndInvoice(t *testing.T, htlcAmt := lnwire.NewMSatFromSatoshis(10000) htlcExpiry := testStartingHeight + testInvoiceCltvExpiry - hops := []hop.ForwardingInfo{ - { - Network: hop.BitcoinNetwork, - NextHop: hop.Exit, - AmountToForward: htlcAmt, - OutgoingCTLV: uint32(htlcExpiry), - }, + hops := []*hop.Payload{ + hop.NewLegacyPayload(&sphinx.HopData{ + Realm: [1]byte{}, // hop.BitcoinNetwork + NextAddress: [8]byte{}, // hop.Exit, + ForwardAmount: uint64(htlcAmt), + OutgoingCltv: uint32(htlcExpiry), + }), } blob, err := generateRoute(hops...) if err != nil { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 69bbec90..7701846e 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -265,16 +265,14 @@ func (s *mockServer) QuitSignal() <-chan struct{} { // mockHopIterator represents the test version of hop iterator which instead // of encrypting the path in onion blob just stores the path as a list of hops. type mockHopIterator struct { - hops []hop.ForwardingInfo + hops []*hop.Payload } -func newMockHopIterator(hops ...hop.ForwardingInfo) hop.Iterator { +func newMockHopIterator(hops ...*hop.Payload) hop.Iterator { return &mockHopIterator{hops: hops} } -func (r *mockHopIterator) ForwardingInstructions() ( - hop.ForwardingInfo, error) { - +func (r *mockHopIterator) HopPayload() (*hop.Payload, error) { h := r.hops[0] r.hops = r.hops[1:] return h, nil @@ -300,7 +298,8 @@ func (r *mockHopIterator) EncodeNextHop(w io.Writer) error { } for _, hop := range r.hops { - if err := encodeFwdInfo(w, &hop); err != nil { + fwdInfo := hop.ForwardingInfo() + if err := encodeFwdInfo(w, &fwdInfo); err != nil { return err } } @@ -434,14 +433,22 @@ func (p *mockIteratorDecoder) DecodeHopIterator(r io.Reader, rHash []byte, } hopLength := binary.BigEndian.Uint32(b[:]) - hops := make([]hop.ForwardingInfo, hopLength) + hops := make([]*hop.Payload, hopLength) for i := uint32(0); i < hopLength; i++ { - f := &hop.ForwardingInfo{} - if err := decodeFwdInfo(r, f); err != nil { + var f hop.ForwardingInfo + if err := decodeFwdInfo(r, &f); err != nil { return nil, lnwire.CodeTemporaryChannelFailure } - hops[i] = *f + var nextHopBytes [8]byte + binary.BigEndian.PutUint64(nextHopBytes[:], f.NextHop.ToUint64()) + + hops[i] = hop.NewLegacyPayload(&sphinx.HopData{ + Realm: [1]byte{}, // hop.BitcoinNetwork + NextAddress: nextHopBytes, + ForwardAmount: uint64(f.AmountToForward), + OutgoingCltv: f.OutgoingCTLV, + }) } return newMockHopIterator(hops...), lnwire.CodeNone diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 6c424b43..44085f90 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -23,6 +23,7 @@ import ( "github.com/btcsuite/fastsha256" "github.com/coreos/bbolt" "github.com/go-errors/errors" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -601,7 +602,7 @@ func generatePayment(invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, } // generateRoute generates the path blob by given array of peers. -func generateRoute(hops ...hop.ForwardingInfo) ( +func generateRoute(hops ...*hop.Payload) ( [lnwire.OnionPacketSize]byte, error) { var blob [lnwire.OnionPacketSize]byte @@ -642,13 +643,12 @@ type threeHopNetwork struct { // also the time lock value needed to route an HTLC with the target amount over // the specified path. func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32, - path ...*channelLink) (lnwire.MilliSatoshi, uint32, - []hop.ForwardingInfo) { + path ...*channelLink) (lnwire.MilliSatoshi, uint32, []*hop.Payload) { totalTimelock := startingHeight runningAmt := payAmt - hops := make([]hop.ForwardingInfo, len(path)) + hops := make([]*hop.Payload, len(path)) for i := len(path) - 1; i >= 0; i-- { // If this is the last hop, then the next hop is the special // "exit node". Otherwise, we look to the "prior" hop. @@ -676,7 +676,7 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32, amount := payAmt if i != len(path)-1 { prevHop := hops[i+1] - prevAmount := prevHop.AmountToForward + prevAmount := prevHop.ForwardingInfo().AmountToForward fee := ExpectedFee(path[i].cfg.FwrdingPolicy, prevAmount) runningAmt += fee @@ -687,12 +687,15 @@ func generateHops(payAmt lnwire.MilliSatoshi, startingHeight uint32, amount = runningAmt - fee } - hops[i] = hop.ForwardingInfo{ - Network: hop.BitcoinNetwork, - NextHop: nextHop, - AmountToForward: amount, - OutgoingCTLV: timeLock, - } + var nextHopBytes [8]byte + binary.BigEndian.PutUint64(nextHopBytes[:], nextHop.ToUint64()) + + hops[i] = hop.NewLegacyPayload(&sphinx.HopData{ + Realm: [1]byte{}, // hop.BitcoinNetwork + NextAddress: nextHopBytes, + ForwardAmount: uint64(amount), + OutgoingCltv: timeLock, + }) } return runningAmt, totalTimelock, hops @@ -739,7 +742,7 @@ func waitForPayFuncResult(payFunc func() error, d time.Duration) error { // * from Alice to Carol through the Bob // * from Alice to some another peer through the Bob func makePayment(sendingPeer, receivingPeer lnpeer.Peer, - firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo, + firstHop lnwire.ShortChannelID, hops []*hop.Payload, invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32) *paymentResponse { @@ -773,7 +776,7 @@ func makePayment(sendingPeer, receivingPeer lnpeer.Peer, // preparePayment creates an invoice at the receivingPeer and returns a function // that, when called, launches the payment from the sendingPeer. func preparePayment(sendingPeer, receivingPeer lnpeer.Peer, - firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo, + firstHop lnwire.ShortChannelID, hops []*hop.Payload, invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32) (*channeldb.Invoice, func() error, error) { @@ -1265,7 +1268,7 @@ func (n *twoHopNetwork) stop() { } func (n *twoHopNetwork) makeHoldPayment(sendingPeer, receivingPeer lnpeer.Peer, - firstHop lnwire.ShortChannelID, hops []hop.ForwardingInfo, + firstHop lnwire.ShortChannelID, hops []*hop.Payload, invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32, preimage lntypes.Preimage) chan error {