htlcswitch: introducing interceptable switch.
In this commit we implement a wrapper arround the switch, called InterceptableSwitch. This kind of wrapper behaves like a proxy which intercepts forwarded packets and allows an external interceptor to signal if it is interested to hold this forward and resolve it manually later or let the switch execute its default behavior. This infrastructure allows the RPC layer to expose interceptor registration API to the user and by that enable the implementation of custom routing behavior.
This commit is contained in:
parent
1a6701122c
commit
0f50d8b2ed
170
htlcswitch/interceptable_switch.go
Normal file
170
htlcswitch/interceptable_switch.go
Normal file
@ -0,0 +1,170 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrFwdNotExists is an error returned when the caller tries to resolve
|
||||
// a forward that doesn't exist anymore.
|
||||
ErrFwdNotExists = errors.New("forward does not exist")
|
||||
)
|
||||
|
||||
// InterceptableSwitch is an implementation of ForwardingSwitch interface.
|
||||
// This implementation is used like a proxy that wraps the switch and
|
||||
// intercepts forward requests. A reference to the Switch is held in order
|
||||
// to communicate back the interception result where the options are:
|
||||
// Resume - forwards the original request to the switch as is.
|
||||
// Settle - routes UpdateFulfillHTLC to the originating link.
|
||||
// Fail - routes UpdateFailHTLC to the originating link.
|
||||
type InterceptableSwitch struct {
|
||||
sync.RWMutex
|
||||
|
||||
// htlcSwitch is the underline switch
|
||||
htlcSwitch *Switch
|
||||
|
||||
// fwdInterceptor is the callback that is called for each forward of
|
||||
// an incoming htlc. It should return true if it is interested in handling
|
||||
// it.
|
||||
fwdInterceptor ForwardInterceptor
|
||||
}
|
||||
|
||||
// NewInterceptableSwitch returns an instance of InterceptableSwitch.
|
||||
func NewInterceptableSwitch(s *Switch) *InterceptableSwitch {
|
||||
return &InterceptableSwitch{htlcSwitch: s}
|
||||
}
|
||||
|
||||
// SetInterceptor sets the ForwardInterceptor to be used.
|
||||
func (s *InterceptableSwitch) SetInterceptor(
|
||||
interceptor ForwardInterceptor) {
|
||||
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.fwdInterceptor = interceptor
|
||||
}
|
||||
|
||||
// ForwardPackets attempts to forward the batch of htlcs through the
|
||||
// switch, any failed packets will be returned to the provided
|
||||
// ChannelLink. The link's quit signal should be provided to allow
|
||||
// cancellation of forwarding during link shutdown.
|
||||
func (s *InterceptableSwitch) ForwardPackets(linkQuit chan struct{},
|
||||
packets ...*htlcPacket) error {
|
||||
|
||||
var interceptor ForwardInterceptor
|
||||
s.Lock()
|
||||
interceptor = s.fwdInterceptor
|
||||
s.Unlock()
|
||||
|
||||
// Optimize for the case we don't have an interceptor.
|
||||
if interceptor == nil {
|
||||
return s.htlcSwitch.ForwardPackets(linkQuit, packets...)
|
||||
}
|
||||
|
||||
var notIntercepted []*htlcPacket
|
||||
for _, p := range packets {
|
||||
if !s.interceptForward(p, interceptor, linkQuit) {
|
||||
notIntercepted = append(notIntercepted, p)
|
||||
}
|
||||
}
|
||||
return s.htlcSwitch.ForwardPackets(linkQuit, notIntercepted...)
|
||||
}
|
||||
|
||||
// interceptForward checks if there is any external interceptor interested in
|
||||
// this packet. Currently only htlc type of UpdateAddHTLC that are forwarded
|
||||
// are being checked for interception. It can be extended in the future given
|
||||
// the right use case.
|
||||
func (s *InterceptableSwitch) interceptForward(packet *htlcPacket,
|
||||
interceptor ForwardInterceptor, linkQuit chan struct{}) bool {
|
||||
|
||||
switch htlc := packet.htlc.(type) {
|
||||
case *lnwire.UpdateAddHTLC:
|
||||
// We are not interested in intercepting initated payments.
|
||||
if packet.incomingChanID == hop.Source {
|
||||
return false
|
||||
}
|
||||
|
||||
intercepted := &interceptedForward{
|
||||
linkQuit: linkQuit,
|
||||
htlc: htlc,
|
||||
packet: packet,
|
||||
htlcSwitch: s.htlcSwitch,
|
||||
}
|
||||
|
||||
// If this htlc was intercepted, don't handle the forward.
|
||||
return interceptor(intercepted)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// interceptedForward implements the InterceptedForward interface.
|
||||
// It is passed from the switch to external interceptors that are interested
|
||||
// in holding forwards and resolve them manually.
|
||||
type interceptedForward struct {
|
||||
linkQuit chan struct{}
|
||||
htlc *lnwire.UpdateAddHTLC
|
||||
packet *htlcPacket
|
||||
htlcSwitch *Switch
|
||||
}
|
||||
|
||||
// Packet returns the intercepted htlc packet.
|
||||
func (f *interceptedForward) Packet() lnwire.UpdateAddHTLC {
|
||||
return *f.htlc
|
||||
}
|
||||
|
||||
// CircuitKey returns the circuit key for the intercepted htlc.
|
||||
func (f *interceptedForward) CircuitKey() channeldb.CircuitKey {
|
||||
return channeldb.CircuitKey{
|
||||
ChanID: f.packet.incomingChanID,
|
||||
HtlcID: f.packet.incomingHTLCID,
|
||||
}
|
||||
}
|
||||
|
||||
// Resume resumes the default behavior as if the packet was not intercepted.
|
||||
func (f *interceptedForward) Resume() error {
|
||||
return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet)
|
||||
}
|
||||
|
||||
// Fail forward a failed packet to the switch.
|
||||
func (f *interceptedForward) Fail() error {
|
||||
reason, err := f.packet.obfuscator.EncryptFirstHop(lnwire.NewTemporaryChannelFailure(nil))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt failure reason %v", err)
|
||||
}
|
||||
return f.resolve(&lnwire.UpdateFailHTLC{
|
||||
Reason: reason,
|
||||
})
|
||||
}
|
||||
|
||||
// Settle forwards a settled packet to the switch.
|
||||
func (f *interceptedForward) Settle(preimage lntypes.Preimage) error {
|
||||
if !preimage.Matches(f.htlc.PaymentHash) {
|
||||
return errors.New("preimage does not match hash")
|
||||
}
|
||||
return f.resolve(&lnwire.UpdateFulfillHTLC{
|
||||
PaymentPreimage: preimage,
|
||||
})
|
||||
}
|
||||
|
||||
// resolve is used for both Settle and Fail and forwards the message to the
|
||||
// switch.
|
||||
func (f *interceptedForward) resolve(message lnwire.Message) error {
|
||||
pkt := &htlcPacket{
|
||||
incomingChanID: f.packet.incomingChanID,
|
||||
incomingHTLCID: f.packet.incomingHTLCID,
|
||||
outgoingChanID: f.packet.outgoingChanID,
|
||||
outgoingHTLCID: f.packet.outgoingHTLCID,
|
||||
isResolution: true,
|
||||
circuit: f.packet.circuit,
|
||||
htlc: message,
|
||||
obfuscator: f.packet.obfuscator,
|
||||
}
|
||||
return f.htlcSwitch.mailOrchestrator.Deliver(pkt.incomingChanID, pkt)
|
||||
}
|
@ -185,6 +185,46 @@ type TowerClient interface {
|
||||
BackupState(*lnwire.ChannelID, *lnwallet.BreachRetribution, bool) error
|
||||
}
|
||||
|
||||
// InterceptableHtlcForwarder is the interface to set the interceptor
|
||||
// implementation that intercepts htlc forwards.
|
||||
type InterceptableHtlcForwarder interface {
|
||||
// SetInterceptor sets a ForwardInterceptor.
|
||||
SetInterceptor(interceptor ForwardInterceptor)
|
||||
}
|
||||
|
||||
// ForwardInterceptor is a function that is invoked from the switch for every
|
||||
// incoming htlc that is intended to be forwarded. It is passed with the
|
||||
// InterceptedForward that contains the information about the packet and a way
|
||||
// to resolve it manually later in case it is held.
|
||||
// The return value indicates if this handler will take control of this forward
|
||||
// and resolve it later or let the switch execute its default behavior.
|
||||
type ForwardInterceptor func(InterceptedForward) bool
|
||||
|
||||
// InterceptedForward is passed to the ForwardInterceptor for every forwarded
|
||||
// htlc. It contains all the information about the packet which accordingly
|
||||
// the interceptor decides if to hold or not.
|
||||
// In addition this interface allows a later resolution by calling either
|
||||
// Resume, Settle or Fail.
|
||||
type InterceptedForward interface {
|
||||
// CircuitKey returns the intercepted packet.
|
||||
CircuitKey() channeldb.CircuitKey
|
||||
|
||||
// Packet returns the intercepted packet.
|
||||
Packet() lnwire.UpdateAddHTLC
|
||||
|
||||
// Resume notifies the intention to resume an existing hold forward. This
|
||||
// basically means the caller wants to resume with the default behavior for
|
||||
// this htlc which usually means forward it.
|
||||
Resume() error
|
||||
|
||||
// Settle notifies the intention to settle an existing hold
|
||||
// forward with a given preimage.
|
||||
Settle(lntypes.Preimage) error
|
||||
|
||||
// Fails notifies the intention to fail an existing hold forward
|
||||
Fail() error
|
||||
}
|
||||
|
||||
// htlcNotifier is an interface which represents the input side of the
|
||||
// HtlcNotifier which htlc events are piped through. This interface is intended
|
||||
// to allow for mocking of the htlcNotifier in tests, so is unexported because
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/btcsuite/fastsha256"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
@ -1679,6 +1680,9 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T,
|
||||
if err := s.ForwardPackets(nil, packet); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We select from all links and extract the error if exists.
|
||||
// The packet must be selected but we don't always expect a link error.
|
||||
var linkError *LinkError
|
||||
select {
|
||||
case p := <-aliceChannelLink.packets:
|
||||
@ -3111,3 +3115,186 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
|
||||
|
||||
return aliceEvents, bobEvents, carolEvents
|
||||
}
|
||||
|
||||
type mockForwardInterceptor struct {
|
||||
intercepted InterceptedForward
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) InterceptForwardHtlc(intercepted InterceptedForward) bool {
|
||||
|
||||
m.intercepted = intercepted
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error {
|
||||
return m.intercepted.Settle(preimage)
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) fail() error {
|
||||
return m.intercepted.Fail()
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) resume() error {
|
||||
return m.intercepted.Resume()
|
||||
}
|
||||
|
||||
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
|
||||
if s.circuits.NumPending() != pending {
|
||||
t.Fatal("wrong amount of half circuits")
|
||||
}
|
||||
if s.circuits.NumOpen() != opened {
|
||||
t.Fatal("wrong amount of circuits")
|
||||
}
|
||||
}
|
||||
|
||||
func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
|
||||
expectReceive bool) {
|
||||
|
||||
// Pull packet from targetLink link.
|
||||
select {
|
||||
case packet := <-targetLink.packets:
|
||||
if !expectReceive {
|
||||
t.Fatal("forward was intercepted, shouldn't land at bob link")
|
||||
} else if err := targetLink.completeCircuit(packet); err != nil {
|
||||
t.Fatalf("unable to complete payment circuit: %v", err)
|
||||
}
|
||||
|
||||
case <-time.After(time.Second):
|
||||
if expectReceive {
|
||||
t.Fatal("request was not propagated to destination")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSwitchHoldForward(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
|
||||
|
||||
alicePeer, err := newMockServer(
|
||||
t, "alice", testStartingHeight, nil, testDefaultDelta,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create alice server: %v", err)
|
||||
}
|
||||
bobPeer, err := newMockServer(
|
||||
t, "bob", testStartingHeight, nil, testDefaultDelta,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create bob server: %v", err)
|
||||
}
|
||||
|
||||
tempPath, err := ioutil.TempDir("", "circuitdb")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to temporary path: %v", err)
|
||||
}
|
||||
|
||||
cdb, err := channeldb.Open(tempPath)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to open channeldb: %v", err)
|
||||
}
|
||||
|
||||
s, err := initSwitchWithDB(testStartingHeight, cdb)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to init switch: %v", err)
|
||||
}
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("unable to start switch: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := s.Stop(); err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
aliceChannelLink := newMockChannelLink(
|
||||
s, chanID1, aliceChanID, alicePeer, true,
|
||||
)
|
||||
bobChannelLink := newMockChannelLink(
|
||||
s, chanID2, bobChanID, bobPeer, true,
|
||||
)
|
||||
if err := s.AddLink(aliceChannelLink); err != nil {
|
||||
t.Fatalf("unable to add alice link: %v", err)
|
||||
}
|
||||
if err := s.AddLink(bobChannelLink); err != nil {
|
||||
t.Fatalf("unable to add bob link: %v", err)
|
||||
}
|
||||
|
||||
// Create request which should be forwarded from Alice channel link to
|
||||
// bob channel link.
|
||||
preimage := [sha256.Size]byte{1}
|
||||
rhash := fastsha256.Sum256(preimage[:])
|
||||
ogPacket := &htlcPacket{
|
||||
incomingChanID: aliceChannelLink.ShortChanID(),
|
||||
incomingHTLCID: 0,
|
||||
outgoingChanID: bobChannelLink.ShortChanID(),
|
||||
obfuscator: NewMockObfuscator(),
|
||||
htlc: &lnwire.UpdateAddHTLC{
|
||||
PaymentHash: rhash,
|
||||
Amount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
forwardInterceptor := &mockForwardInterceptor{}
|
||||
switchForwardInterceptor := NewInterceptableSwitch(s)
|
||||
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
|
||||
linkQuit := make(chan struct{})
|
||||
|
||||
// Test resume a hold forward
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
}
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
|
||||
if err := forwardInterceptor.resume(); err != nil {
|
||||
t.Fatalf("failed to resume forward")
|
||||
}
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, true)
|
||||
assertNumCircuits(t, s, 1, 1)
|
||||
|
||||
// settling the htlc to close the circuit.
|
||||
settle := &htlcPacket{
|
||||
outgoingChanID: bobChannelLink.ShortChanID(),
|
||||
outgoingHTLCID: 0,
|
||||
amount: 1,
|
||||
htlc: &lnwire.UpdateFulfillHTLC{
|
||||
PaymentPreimage: preimage,
|
||||
},
|
||||
}
|
||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, settle); err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
}
|
||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
// Test failing a hold forward
|
||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
}
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
|
||||
if err := forwardInterceptor.fail(); err != nil {
|
||||
t.Fatalf("failed to cancel forward %v", err)
|
||||
}
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
// Test settling a hold forward
|
||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
}
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
|
||||
if err := forwardInterceptor.settle(preimage); err != nil {
|
||||
t.Fatal("failed to cancel forward")
|
||||
}
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
}
|
||||
|
2
peer.go
2
peer.go
@ -659,7 +659,7 @@ func (p *peer) addLink(chanPoint *wire.OutPoint,
|
||||
Registry: p.server.invoices,
|
||||
Switch: p.server.htlcSwitch,
|
||||
Circuits: p.server.htlcSwitch.CircuitModifier(),
|
||||
ForwardPackets: p.server.htlcSwitch.ForwardPackets,
|
||||
ForwardPackets: p.server.interceptableSwitch.ForwardPackets,
|
||||
FwrdingPolicy: *forwardingPolicy,
|
||||
FeeEstimator: p.server.cc.feeEstimator,
|
||||
PreimageCache: p.server.witnessBeacon,
|
||||
|
@ -207,6 +207,8 @@ type server struct {
|
||||
|
||||
htlcSwitch *htlcswitch.Switch
|
||||
|
||||
interceptableSwitch *htlcswitch.InterceptableSwitch
|
||||
|
||||
invoices *invoices.InvoiceRegistry
|
||||
|
||||
channelNotifier *channelnotifier.ChannelNotifier
|
||||
@ -515,6 +517,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, chanDB *channeldb.DB,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.interceptableSwitch = htlcswitch.NewInterceptableSwitch(s.htlcSwitch)
|
||||
|
||||
chanStatusMgrCfg := &netann.ChanStatusConfig{
|
||||
ChanStatusSampleInterval: cfg.ChanStatusSampleInterval,
|
||||
|
Loading…
Reference in New Issue
Block a user