diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 9af694d6..54bccab5 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -166,7 +166,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) { // identical to HTLC resolution in the link. event, err := h.Registry.NotifyExitHopHtlc( h.payHash, h.htlcAmt, h.htlcExpiry, currentHeight, - hodlChan, + hodlChan, nil, ) switch err { case channeldb.ErrInvoiceNotFound: diff --git a/contractcourt/interfaces.go b/contractcourt/interfaces.go index a18ee3d6..3333e6b6 100644 --- a/contractcourt/interfaces.go +++ b/contractcourt/interfaces.go @@ -22,7 +22,8 @@ type Registry interface { // the resolution is sent on the passed in hodlChan later. NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi, expiry uint32, currentHeight int32, - hodlChan chan<- interface{}) (*invoices.HodlEvent, error) + hodlChan chan<- interface{}, + eob []byte) (*invoices.HodlEvent, error) // HodlUnsubscribeAll unsubscribes from all hodl events. HodlUnsubscribeAll(subscriber chan<- interface{}) diff --git a/contractcourt/mock_registry_test.go b/contractcourt/mock_registry_test.go index f54a1465..288ea5ba 100644 --- a/contractcourt/mock_registry_test.go +++ b/contractcourt/mock_registry_test.go @@ -23,7 +23,7 @@ type mockRegistry struct { func (r *mockRegistry) NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi, expiry uint32, currentHeight int32, - hodlChan chan<- interface{}) (*invoices.HodlEvent, error) { + hodlChan chan<- interface{}, eob []byte) (*invoices.HodlEvent, error) { r.notifyChan <- notifyExitHopData{ hodlChan: hodlChan, diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index 4dff21a5..52a4c194 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -23,10 +23,13 @@ type InvoiceDatabase interface { // invoice is a debug invoice, then this method is a noop as debug // invoices are never fully settled. The return value describes how the // htlc should be resolved. If the htlc cannot be resolved immediately, - // the resolution is sent on the passed in hodlChan later. + // the resolution is sent on the passed in hodlChan later. The eob + // field passes the entire onion hop payload into the invoice registry + // for decoding purposes. NotifyExitHopHtlc(payHash lntypes.Hash, paidAmount lnwire.MilliSatoshi, expiry uint32, currentHeight int32, - hodlChan chan<- interface{}) (*invoices.HodlEvent, error) + hodlChan chan<- interface{}, + eob []byte) (*invoices.HodlEvent, error) // CancelInvoice attempts to cancel the invoice corresponding to the // passed payment hash. diff --git a/htlcswitch/iterator.go b/htlcswitch/iterator.go index a068000d..6e3ae07d 100644 --- a/htlcswitch/iterator.go +++ b/htlcswitch/iterator.go @@ -1,12 +1,15 @@ package htlcswitch import ( + "bytes" "encoding/binary" + "fmt" "io" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" ) // NetworkHop indicates the blockchain network that is intended to be the next @@ -85,7 +88,10 @@ type HopIterator interface { // Additionally, the information encoded within the returned // ForwardingInfo is to be used by each hop to authenticate the // information given to it by the prior hop. - ForwardingInstructions() ForwardingInfo + ForwardingInstructions() (ForwardingInfo, error) + + // ExtraOnionBlob returns the additional EOB data (if available). + ExtraOnionBlob() []byte // EncodeNextHop encodes the onion packet destined for the next hop // into the passed io.Writer. @@ -139,24 +145,79 @@ func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error { // hop to authenticate the information given to it by the prior hop. // // NOTE: Part of the HopIterator interface. -func (r *sphinxHopIterator) ForwardingInstructions() ForwardingInfo { - fwdInst := r.processedPacket.ForwardingInstructions +func (r *sphinxHopIterator) ForwardingInstructions() (ForwardingInfo, error) { + var ( + nextHop lnwire.ShortChannelID + amt uint64 + cltv uint32 + ) - var nextHop lnwire.ShortChannelID - switch r.processedPacket.Action { - case sphinx.ExitNode: - nextHop = exitHop - case sphinx.MoreHops: - s := binary.BigEndian.Uint64(fwdInst.NextAddress[:]) - nextHop = lnwire.NewShortChanIDFromInt(s) + switch r.processedPacket.Payload.Type { + // If this is the legacy payload, then we'll extract the information + // directly from the pre-populated ForwardingInstructions field. + case sphinx.PayloadLegacy: + fwdInst := r.processedPacket.ForwardingInstructions + + switch r.processedPacket.Action { + case sphinx.ExitNode: + nextHop = exitHop + case sphinx.MoreHops: + s := binary.BigEndian.Uint64(fwdInst.NextAddress[:]) + nextHop = lnwire.NewShortChanIDFromInt(s) + } + + amt = fwdInst.ForwardAmount + cltv = fwdInst.OutgoingCltv + + // Otherwise, if this is the TLV payload, then we'll make a new stream + // to decode only what we need to make routing decisions. + case sphinx.PayloadTLV: + var cid uint64 + + tlvStream, err := tlv.NewStream( + tlv.MakeDynamicRecord( + tlv.AmtOnionType, &amt, nil, + tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakeDynamicRecord( + tlv.LockTimeOnionType, &cltv, nil, + tlv.ETUint32, tlv.DTUint32, + ), + tlv.MakePrimitiveRecord(tlv.NextHopOnionType, &cid), + ) + if err != nil { + return ForwardingInfo{}, err + } + + err = tlvStream.Decode(bytes.NewReader( + r.processedPacket.Payload.Payload, + )) + if err != nil { + return ForwardingInfo{}, err + } + + nextHop = lnwire.NewShortChanIDFromInt(cid) + + default: + return ForwardingInfo{}, fmt.Errorf("unknown sphinx payload "+ + "type: %v", r.processedPacket.Payload.Type) } return ForwardingInfo{ Network: BitcoinHop, NextHop: nextHop, - AmountToForward: lnwire.MilliSatoshi(fwdInst.ForwardAmount), - OutgoingCTLV: fwdInst.OutgoingCltv, + AmountToForward: lnwire.MilliSatoshi(amt), + OutgoingCTLV: cltv, + }, nil +} + +// ExtraOnionBlob returns the additional EOB data (if available). +func (r *sphinxHopIterator) ExtraOnionBlob() []byte { + if r.processedPacket.Payload.Type == sphinx.PayloadLegacy { + return nil } + + return r.processedPacket.Payload.Payload } // ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop, diff --git a/htlcswitch/iterator_test.go b/htlcswitch/iterator_test.go new file mode 100644 index 00000000..01c28ed9 --- /dev/null +++ b/htlcswitch/iterator_test.go @@ -0,0 +1,109 @@ +package htlcswitch + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/davecgh/go-spew/spew" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// TestSphinxHopIteratorForwardingInstructions tests that we're able to +// properly decode an onion payload, no matter the payload type, into the +// original set of forwarding instructions. +func TestSphinxHopIteratorForwardingInstructions(t *testing.T) { + t.Parallel() + + // First, we'll make the hop data that the sender would create to send + // an HTLC through our imaginary route. + hopData := sphinx.HopData{ + ForwardAmount: 100000, + OutgoingCltv: 4343, + } + copy(hopData.NextAddress[:], bytes.Repeat([]byte("a"), 8)) + + // Next, we'll make the hop forwarding information that we should + // extract each type, no matter the payload type. + nextAddrInt := binary.BigEndian.Uint64(hopData.NextAddress[:]) + expectedFwdInfo := ForwardingInfo{ + NextHop: lnwire.NewShortChanIDFromInt(nextAddrInt), + AmountToForward: lnwire.MilliSatoshi(hopData.ForwardAmount), + OutgoingCTLV: hopData.OutgoingCltv, + } + + // For our TLV payload, we'll serialize the hop into into a TLV stream + // as we would normally in the routing network. + var b bytes.Buffer + tlvRecords := []tlv.Record{ + tlv.MakeDynamicRecord( + tlv.AmtOnionType, &hopData.ForwardAmount, func() uint64 { + return tlv.SizeTUint64(hopData.ForwardAmount) + }, + tlv.ETUint64, tlv.DTUint64, + ), + tlv.MakeDynamicRecord( + tlv.LockTimeOnionType, &hopData.OutgoingCltv, func() uint64 { + return tlv.SizeTUint32(hopData.OutgoingCltv) + }, + tlv.ETUint32, tlv.DTUint32, + ), + tlv.MakePrimitiveRecord(tlv.NextHopOnionType, &nextAddrInt), + } + tlvStream, err := tlv.NewStream(tlvRecords...) + if err != nil { + t.Fatalf("unable to create stream: %v", err) + } + if err := tlvStream.Encode(&b); err != nil { + t.Fatalf("unable to encode stream: %v", err) + } + + var testCases = []struct { + sphinxPacket *sphinx.ProcessedPacket + expectedFwdInfo ForwardingInfo + }{ + // A regular legacy payload that signals more hops. + { + sphinxPacket: &sphinx.ProcessedPacket{ + Payload: sphinx.HopPayload{ + Type: sphinx.PayloadLegacy, + }, + Action: sphinx.MoreHops, + ForwardingInstructions: &hopData, + }, + expectedFwdInfo: expectedFwdInfo, + }, + // A TLV payload, we can leave off the action as we'll always + // read the cid encoded. + { + sphinxPacket: &sphinx.ProcessedPacket{ + Payload: sphinx.HopPayload{ + Type: sphinx.PayloadTLV, + Payload: b.Bytes(), + }, + }, + expectedFwdInfo: expectedFwdInfo, + }, + } + + // Finally, we'll test that we get the same set of + // ForwardingInstructions for each payload type. + iterator := sphinxHopIterator{} + for i, testCase := range testCases { + iterator.processedPacket = testCase.sphinxPacket + + fwdInfo, err := iterator.ForwardingInstructions() + if err != nil { + t.Fatalf("#%v: unable to extract forwarding "+ + "instructions: %v", i, err) + } + + if fwdInfo != testCase.expectedFwdInfo { + t.Fatalf("#%v: wrong fwding info: expected %v, got %v", + i, spew.Sdump(testCase.expectedFwdInfo), + spew.Sdump(fwdInfo)) + } + } +} diff --git a/htlcswitch/link.go b/htlcswitch/link.go index bd79b621..7d14b7af 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -2627,8 +2627,9 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // If we're unable to process the onion blob than we // should send the malformed htlc error to payment // sender. - l.sendMalformedHTLCError(pd.HtlcIndex, failureCode, - onionBlob[:], pd.SourceRef) + l.sendMalformedHTLCError( + pd.HtlcIndex, failureCode, onionBlob[:], pd.SourceRef, + ) needUpdate = true log.Errorf("unable to decode onion "+ @@ -2638,11 +2639,29 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, heightNow := l.cfg.Switch.BestHeight() - fwdInfo := chanIterator.ForwardingInstructions() + fwdInfo, err := chanIterator.ForwardingInstructions() + if err != nil { + // If we're unable to process the onion payload, or we + // we received malformed TLV stream, then we should + // send an error back to the caller so the HTLC can be + // cancelled. + l.sendHTLCError( + pd.HtlcIndex, + lnwire.NewInvalidOnionVersion(onionBlob[:]), + obfuscator, pd.SourceRef, + ) + needUpdate = true + + log.Errorf("Unable to decode forwarding "+ + "instructions: %v", err) + continue + } + switch fwdInfo.NextHop { case exitHop: updated, err := l.processExitHop( pd, obfuscator, fwdInfo, heightNow, + chanIterator.ExtraOnionBlob(), ) if err != nil { l.fail(LinkFailureError{code: ErrInternalError}, @@ -2814,8 +2833,8 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg, // processExitHop handles an htlc for which this link is the exit hop. It // returns a boolean indicating whether the commitment tx needs an update. func (l *channelLink) processExitHop(pd *lnwallet.PaymentDescriptor, - obfuscator ErrorEncrypter, fwdInfo ForwardingInfo, heightNow uint32) ( - bool, error) { + obfuscator ErrorEncrypter, fwdInfo ForwardingInfo, + heightNow uint32, eob []byte) (bool, error) { // If hodl.ExitSettle is requested, we will not validate the final hop's // ADD, nor will we settle the corresponding invoice or respond with the @@ -2861,7 +2880,7 @@ func (l *channelLink) processExitHop(pd *lnwallet.PaymentDescriptor, event, err := l.cfg.Registry.NotifyExitHopHtlc( invoiceHash, pd.Amount, pd.Timeout, int32(heightNow), - l.hodlQueue.ChanIn(), + l.hodlQueue.ChanIn(), eob, ) switch err { diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 9c5335a7..073e825d 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -275,10 +275,14 @@ func newMockHopIterator(hops ...ForwardingInfo) HopIterator { return &mockHopIterator{hops: hops} } -func (r *mockHopIterator) ForwardingInstructions() ForwardingInfo { +func (r *mockHopIterator) ForwardingInstructions() (ForwardingInfo, error) { h := r.hops[0] r.hops = r.hops[1:] - return h + return h, nil +} + +func (r *mockHopIterator) ExtraOnionBlob() []byte { + return nil } func (r *mockHopIterator) ExtractErrorEncrypter( @@ -789,10 +793,10 @@ func (i *mockInvoiceRegistry) SettleHodlInvoice(preimage lntypes.Preimage) error func (i *mockInvoiceRegistry) NotifyExitHopHtlc(rhash lntypes.Hash, amt lnwire.MilliSatoshi, expiry uint32, currentHeight int32, - hodlChan chan<- interface{}) (*invoices.HodlEvent, error) { + hodlChan chan<- interface{}, eob []byte) (*invoices.HodlEvent, error) { event, err := i.registry.NotifyExitHopHtlc( - rhash, amt, expiry, currentHeight, hodlChan, + rhash, amt, expiry, currentHeight, hodlChan, eob, ) if err != nil { return nil, err diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 46166c4d..a5570a3c 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -489,7 +489,7 @@ func (i *InvoiceRegistry) checkHtlcParameters(invoice *channeldb.Invoice, // prevent deadlock. func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash, amtPaid lnwire.MilliSatoshi, expiry uint32, currentHeight int32, - hodlChan chan<- interface{}) (*HodlEvent, error) { + hodlChan chan<- interface{}, eob []byte) (*HodlEvent, error) { i.Lock() defer i.Unlock() diff --git a/invoices/invoiceregistry_test.go b/invoices/invoiceregistry_test.go index ff3c04be..c72db33e 100644 --- a/invoices/invoiceregistry_test.go +++ b/invoices/invoiceregistry_test.go @@ -119,7 +119,7 @@ func TestSettleInvoice(t *testing.T) { // Settle invoice with a slightly higher amount. amtPaid := lnwire.MilliSatoshi(100500) _, err = registry.NotifyExitHopHtlc( - hash, amtPaid, testInvoiceExpiry, 0, hodlChan, + hash, amtPaid, testInvoiceExpiry, 0, hodlChan, nil, ) if err != nil { t.Fatal(err) @@ -155,6 +155,7 @@ func TestSettleInvoice(t *testing.T) { // restart. event, err := registry.NotifyExitHopHtlc( hash, amtPaid, testInvoiceExpiry, testCurrentHeight, hodlChan, + nil, ) if err != nil { t.Fatalf("unexpected NotifyExitHopHtlc error: %v", err) @@ -168,7 +169,7 @@ func TestSettleInvoice(t *testing.T) { // same. New HTLCs with a different amount should be rejected. event, err = registry.NotifyExitHopHtlc( hash, amtPaid+600, testInvoiceExpiry, testCurrentHeight, - hodlChan, + hodlChan, nil, ) if err != nil { t.Fatalf("unexpected NotifyExitHopHtlc error: %v", err) @@ -181,7 +182,7 @@ func TestSettleInvoice(t *testing.T) { // behaviour as settling with a higher amount. event, err = registry.NotifyExitHopHtlc( hash, amtPaid-600, testInvoiceExpiry, testCurrentHeight, - hodlChan, + hodlChan, nil, ) if err != nil { t.Fatalf("unexpected NotifyExitHopHtlc error: %v", err) @@ -304,7 +305,7 @@ func TestCancelInvoice(t *testing.T) { // succeed. hodlChan := make(chan interface{}) event, err := registry.NotifyExitHopHtlc( - hash, amt, testInvoiceExpiry, testCurrentHeight, hodlChan, + hash, amt, testInvoiceExpiry, testCurrentHeight, hodlChan, nil, ) if err != nil { t.Fatal("expected settlement of a canceled invoice to succeed") @@ -381,6 +382,7 @@ func TestHoldInvoice(t *testing.T) { // should be possible. event, err := registry.NotifyExitHopHtlc( hash, amtPaid, testInvoiceExpiry, testCurrentHeight, hodlChan, + nil, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) @@ -392,6 +394,7 @@ func TestHoldInvoice(t *testing.T) { // Test idempotency. event, err = registry.NotifyExitHopHtlc( hash, amtPaid, testInvoiceExpiry, testCurrentHeight, hodlChan, + nil, ) if err != nil { t.Fatalf("expected settle to succeed but got %v", err) @@ -487,7 +490,7 @@ func TestUnknownInvoice(t *testing.T) { hodlChan := make(chan interface{}) amt := lnwire.MilliSatoshi(100000) _, err := registry.NotifyExitHopHtlc( - hash, amt, testInvoiceExpiry, testCurrentHeight, hodlChan, + hash, amt, testInvoiceExpiry, testCurrentHeight, hodlChan, nil, ) if err != channeldb.ErrInvoiceNotFound { t.Fatal("expected invoice not found error")