diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go new file mode 100644 index 00000000..cdf5bae0 --- /dev/null +++ b/htlcswitch/interceptable_switch.go @@ -0,0 +1,170 @@ +package htlcswitch + +import ( + "fmt" + "sync" + + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrFwdNotExists is an error returned when the caller tries to resolve + // a forward that doesn't exist anymore. + ErrFwdNotExists = errors.New("forward does not exist") +) + +// InterceptableSwitch is an implementation of ForwardingSwitch interface. +// This implementation is used like a proxy that wraps the switch and +// intercepts forward requests. A reference to the Switch is held in order +// to communicate back the interception result where the options are: +// Resume - forwards the original request to the switch as is. +// Settle - routes UpdateFulfillHTLC to the originating link. +// Fail - routes UpdateFailHTLC to the originating link. +type InterceptableSwitch struct { + sync.RWMutex + + // htlcSwitch is the underline switch + htlcSwitch *Switch + + // fwdInterceptor is the callback that is called for each forward of + // an incoming htlc. It should return true if it is interested in handling + // it. + fwdInterceptor ForwardInterceptor +} + +// NewInterceptableSwitch returns an instance of InterceptableSwitch. +func NewInterceptableSwitch(s *Switch) *InterceptableSwitch { + return &InterceptableSwitch{htlcSwitch: s} +} + +// SetInterceptor sets the ForwardInterceptor to be used. +func (s *InterceptableSwitch) SetInterceptor( + interceptor ForwardInterceptor) { + + s.Lock() + defer s.Unlock() + s.fwdInterceptor = interceptor +} + +// ForwardPackets attempts to forward the batch of htlcs through the +// switch, any failed packets will be returned to the provided +// ChannelLink. The link's quit signal should be provided to allow +// cancellation of forwarding during link shutdown. +func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{}, + packets ...*htlcPacket) error { + + var interceptor ForwardInterceptor + s.Lock() + interceptor = s.fwdInterceptor + s.Unlock() + + // Optimize for the case we don't have an interceptor. + if interceptor == nil { + return s.htlcSwitch.ForwardPackets(linkQuit, packets...) + } + + var notIntercepted []*htlcPacket + for _, p := range packets { + if !s.interceptForward(p, interceptor, linkQuit) { + notIntercepted = append(notIntercepted, p) + } + } + return s.htlcSwitch.ForwardPackets(linkQuit, notIntercepted...) +} + +// interceptForward checks if there is any external interceptor interested in +// this packet. Currently only htlc type of UpdateAddHTLC that are forwarded +// are being checked for interception. It can be extended in the future given +// the right use case. +func (s *InterceptableSwitch) interceptForward(packet *htlcPacket, + interceptor ForwardInterceptor, linkQuit chan struct{}) bool { + + switch htlc := packet.htlc.(type) { + case *lnwire.UpdateAddHTLC: + // We are not interested in intercepting initated payments. + if packet.incomingChanID == hop.Source { + return false + } + + intercepted := &interceptedForward{ + linkQuit: linkQuit, + htlc: htlc, + packet: packet, + htlcSwitch: s.htlcSwitch, + } + + // If this htlc was intercepted, don't handle the forward. + return interceptor(intercepted) + default: + return false + } +} + +// interceptedForward implements the InterceptedForward interface. +// It is passed from the switch to external interceptors that are interested +// in holding forwards and resolve them manually. +type interceptedForward struct { + linkQuit chan struct{} + htlc *lnwire.UpdateAddHTLC + packet *htlcPacket + htlcSwitch *Switch +} + +// Packet returns the intercepted htlc packet. +func (f *interceptedForward) Packet() lnwire.UpdateAddHTLC { + return *f.htlc +} + +// CircuitKey returns the circuit key for the intercepted htlc. +func (f *interceptedForward) CircuitKey() channeldb.CircuitKey { + return channeldb.CircuitKey{ + ChanID: f.packet.incomingChanID, + HtlcID: f.packet.incomingHTLCID, + } +} + +// Resume resumes the default behavior as if the packet was not intercepted. +func (f *interceptedForward) Resume() error { + return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet) +} + +// Fail forward a failed packet to the switch. +func (f *interceptedForward) Fail() error { + reason, err := f.packet.obfuscator.EncryptFirstHop(lnwire.NewTemporaryChannelFailure(nil)) + if err != nil { + return fmt.Errorf("failed to encrypt failure reason %v", err) + } + return f.resolve(&lnwire.UpdateFailHTLC{ + Reason: reason, + }) +} + +// Settle forwards a settled packet to the switch. +func (f *interceptedForward) Settle(preimage lntypes.Preimage) error { + if !preimage.Matches(f.htlc.PaymentHash) { + return errors.New("preimage does not match hash") + } + return f.resolve(&lnwire.UpdateFulfillHTLC{ + PaymentPreimage: preimage, + }) +} + +// resolve is used for both Settle and Fail and forwards the message to the +// switch. +func (f *interceptedForward) resolve(message lnwire.Message) error { + pkt := &htlcPacket{ + incomingChanID: f.packet.incomingChanID, + incomingHTLCID: f.packet.incomingHTLCID, + outgoingChanID: f.packet.outgoingChanID, + outgoingHTLCID: f.packet.outgoingHTLCID, + isResolution: true, + circuit: f.packet.circuit, + htlc: message, + obfuscator: f.packet.obfuscator, + } + return f.htlcSwitch.mailOrchestrator.Deliver(pkt.incomingChanID, pkt) +} diff --git a/htlcswitch/interfaces.go b/htlcswitch/interfaces.go index f51f1b15..67b1a458 100644 --- a/htlcswitch/interfaces.go +++ b/htlcswitch/interfaces.go @@ -185,6 +185,46 @@ type TowerClient interface { BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error } +// InterceptableHtlcForwarder is the interface to set the interceptor +// implementation that intercepts htlc forwards. +type InterceptableHtlcForwarder interface { + // SetInterceptor sets a ForwardInterceptor. + SetInterceptor(interceptor ForwardInterceptor) +} + +// ForwardInterceptor is a function that is invoked from the switch for every +// incoming htlc that is intended to be forwarded. It is passed with the +// InterceptedForward that contains the information about the packet and a way +// to resolve it manually later in case it is held. +// The return value indicates if this handler will take control of this forward +// and resolve it later or let the switch execute its default behavior. +type ForwardInterceptor func(InterceptedForward) bool + +// InterceptedForward is passed to the ForwardInterceptor for every forwarded +// htlc. It contains all the information about the packet which accordingly +// the interceptor decides if to hold or not. +// In addition this interface allows a later resolution by calling either +// Resume, Settle or Fail. +type InterceptedForward interface { + // CircuitKey returns the intercepted packet. + CircuitKey() channeldb.CircuitKey + + // Packet returns the intercepted packet. + Packet() lnwire.UpdateAddHTLC + + // Resume notifies the intention to resume an existing hold forward. This + // basically means the caller wants to resume with the default behavior for + // this htlc which usually means forward it. + Resume() error + + // Settle notifies the intention to settle an existing hold + // forward with a given preimage. + Settle(lntypes.Preimage) error + + // Fails notifies the intention to fail an existing hold forward + Fail() error +} + // htlcNotifier is an interface which represents the input side of the // HtlcNotifier which htlc events are piped through. This interface is intended // to allow for mocking of the htlcNotifier in tests, so is unexported because diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index d2977f7d..a31fb749 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/btcsuite/btcutil" + "github.com/btcsuite/fastsha256" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch/hop" @@ -1679,6 +1680,9 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T, if err := s.ForwardPackets(nil, packet); err != nil { t.Fatal(err) } + + // We select from all links and extract the error if exists. + // The packet must be selected but we don't always expect a link error. var linkError *LinkError select { case p := <-aliceChannelLink.packets: @@ -3111,3 +3115,186 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64, return aliceEvents, bobEvents, carolEvents } + +type mockForwardInterceptor struct { + intercepted InterceptedForward +} + +func (m *mockForwardInterceptor) InterceptForwardHtlc(intercepted InterceptedForward) bool { + + m.intercepted = intercepted + return true +} + +func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error { + return m.intercepted.Settle(preimage) +} + +func (m *mockForwardInterceptor) fail() error { + return m.intercepted.Fail() +} + +func (m *mockForwardInterceptor) resume() error { + return m.intercepted.Resume() +} + +func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) { + if s.circuits.NumPending() != pending { + t.Fatal("wrong amount of half circuits") + } + if s.circuits.NumOpen() != opened { + t.Fatal("wrong amount of circuits") + } +} + +func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink, + expectReceive bool) { + + // Pull packet from targetLink link. + select { + case packet := <-targetLink.packets: + if !expectReceive { + t.Fatal("forward was intercepted, shouldn't land at bob link") + } else if err := targetLink.completeCircuit(packet); err != nil { + t.Fatalf("unable to complete payment circuit: %v", err) + } + + case <-time.After(time.Second): + if expectReceive { + t.Fatal("request was not propagated to destination") + } + } +} + +func TestSwitchHoldForward(t *testing.T) { + t.Parallel() + + chanID1, chanID2, aliceChanID, bobChanID := genIDs() + + alicePeer, err := newMockServer( + t, "alice", testStartingHeight, nil, testDefaultDelta, + ) + if err != nil { + t.Fatalf("unable to create alice server: %v", err) + } + bobPeer, err := newMockServer( + t, "bob", testStartingHeight, nil, testDefaultDelta, + ) + if err != nil { + t.Fatalf("unable to create bob server: %v", err) + } + + tempPath, err := ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to temporary path: %v", err) + } + + cdb, err := channeldb.Open(tempPath) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + s, err := initSwitchWithDB(testStartingHeight, cdb) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + + defer func() { + if err := s.Stop(); err != nil { + t.Fatalf(err.Error()) + } + }() + + aliceChannelLink := newMockChannelLink( + s, chanID1, aliceChanID, alicePeer, true, + ) + bobChannelLink := newMockChannelLink( + s, chanID2, bobChanID, bobPeer, true, + ) + if err := s.AddLink(aliceChannelLink); err != nil { + t.Fatalf("unable to add alice link: %v", err) + } + if err := s.AddLink(bobChannelLink); err != nil { + t.Fatalf("unable to add bob link: %v", err) + } + + // Create request which should be forwarded from Alice channel link to + // bob channel link. + preimage := [sha256.Size]byte{1} + rhash := fastsha256.Sum256(preimage[:]) + ogPacket := &htlcPacket{ + incomingChanID: aliceChannelLink.ShortChanID(), + incomingHTLCID: 0, + outgoingChanID: bobChannelLink.ShortChanID(), + obfuscator: NewMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + + forwardInterceptor := &mockForwardInterceptor{} + switchForwardInterceptor := NewInterceptableSwitch(s) + switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc) + linkQuit := make(chan struct{}) + + // Test resume a hold forward + assertNumCircuits(t, s, 0, 0) + if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { + t.Fatalf("can't forward htlc packet: %v", err) + } + assertNumCircuits(t, s, 0, 0) + assertOutgoingLinkReceive(t, bobChannelLink, false) + + if err := forwardInterceptor.resume(); err != nil { + t.Fatalf("failed to resume forward") + } + assertOutgoingLinkReceive(t, bobChannelLink, true) + assertNumCircuits(t, s, 1, 1) + + // settling the htlc to close the circuit. + settle := &htlcPacket{ + outgoingChanID: bobChannelLink.ShortChanID(), + outgoingHTLCID: 0, + amount: 1, + htlc: &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: preimage, + }, + } + if err := switchForwardInterceptor.ForwardPackets(linkQuit, settle); err != nil { + t.Fatalf("can't forward htlc packet: %v", err) + } + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + // Test failing a hold forward + if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { + t.Fatalf("can't forward htlc packet: %v", err) + } + assertNumCircuits(t, s, 0, 0) + assertOutgoingLinkReceive(t, bobChannelLink, false) + + if err := forwardInterceptor.fail(); err != nil { + t.Fatalf("failed to cancel forward %v", err) + } + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) + + // Test settling a hold forward + if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil { + t.Fatalf("can't forward htlc packet: %v", err) + } + assertNumCircuits(t, s, 0, 0) + assertOutgoingLinkReceive(t, bobChannelLink, false) + + if err := forwardInterceptor.settle(preimage); err != nil { + t.Fatal("failed to cancel forward") + } + assertOutgoingLinkReceive(t, bobChannelLink, false) + assertOutgoingLinkReceive(t, aliceChannelLink, true) + assertNumCircuits(t, s, 0, 0) +} diff --git a/peer.go b/peer.go index a43f254b..3ea12061 100644 --- a/peer.go +++ b/peer.go @@ -659,7 +659,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint, Registry: p.server.invoices, Switch: p.server.htlcSwitch, Circuits: p.server.htlcSwitch.CircuitModifier(), - ForwardPackets: p.server.htlcSwitch.ForwardPackets, + ForwardPackets: p.server.interceptableSwitch.ForwardPackets, FwrdingPolicy: *forwardingPolicy, FeeEstimator: p.server.cc.feeEstimator, PreimageCache: p.server.witnessBeacon, diff --git a/server.go b/server.go index e8d33320..633bafe4 100644 --- a/server.go +++ b/server.go @@ -207,6 +207,8 @@ type server struct { htlcSwitch *htlcswitch.Switch + interceptableSwitch *htlcswitch.InterceptableSwitch + invoices *invoices.InvoiceRegistry channelNotifier *channelnotifier.ChannelNotifier @@ -515,6 +517,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, chanDB *channeldb.DB, if err != nil { return nil, err } + s.interceptableSwitch = htlcswitch.NewInterceptableSwitch(s.htlcSwitch) chanStatusMgrCfg := &netann.ChanStatusConfig{ ChanStatusSampleInterval: cfg.ChanStatusSampleInterval,