// Source: lnd.xprv/lnrpc/routerrpc/forward_interceptor.go (222 lines, 6.8 KiB, Go)

package routerrpc
import (
"errors"
"fmt"
"sync"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
	// ErrFwdNotExists is the error handed back when a caller attempts to
	// resolve a forward that is not, or is no longer, being held.
	ErrFwdNotExists = errors.New("forward does not exist")

	// ErrMissingPreimage is the error handed back when a caller attempts
	// to settle a forward without supplying a preimage.
	ErrMissingPreimage = errors.New("missing preimage")
)
// forwardInterceptor is a helper that manages the lifecycle of a single RPC
// interceptor streaming session. It is created when the stream opens and
// disconnects when the stream closes.
type forwardInterceptor struct {
	// server is a reference to the owning Server.
	server *Server

	// holdForwards maps the incoming circuit key of every forward that is
	// currently held to the InterceptedForward used to resolve it.
	holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward

	// stream is the bidirectional RPC stream shared with the client.
	stream Router_HtlcInterceptorServer

	// quit is closed when this forwardInterceptor is shutting down.
	quit chan struct{}

	// intercepted carries the intercepted packets coming from the switch
	// to the main loop.
	intercepted chan htlcswitch.InterceptedForward

	// wg tracks the goroutine that reads client responses, so that
	// shutdown can wait for it to exit.
	wg sync.WaitGroup
}
// newForwardInterceptor builds a forwardInterceptor bound to the given server
// and client stream. It starts out with an empty set of held forwards and
// freshly allocated, unbuffered quit and intercepted channels.
func newForwardInterceptor(server *Server,
	stream Router_HtlcInterceptorServer) *forwardInterceptor {

	held := make(map[channeldb.CircuitKey]htlcswitch.InterceptedForward)

	interceptor := &forwardInterceptor{
		server:       server,
		stream:       stream,
		holdForwards: held,
		quit:         make(chan struct{}),
		intercepted:  make(chan htlcswitch.InterceptedForward),
	}

	return interceptor
}
// run drives the interceptor session: it sends the intercepted packets to the
// client and applies the corresponding responses. On one hand it registers
// itself as the interceptor that receives the switch packets, and on the other
// hand it launches a goroutine that reads from the client stream. To
// coordinate the two and keep access to the held forwards safe, both inputs
// are funneled into the main loop below, which is the only place where they
// are handled.
//
// It returns the error of a failed send to the client, the error reported by
// the client stream reader, or nil once the server's quit channel fires. On
// return the interceptor is unregistered and onDisconnect is run.
func (r *forwardInterceptor) run() error {
	// Make sure we disconnect and resolve all remaining packets, if any.
	defer r.onDisconnect()

	// Register our interceptor so we receive all forwarded packets.
	interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder
	interceptableForwarder.SetInterceptor(r.onIntercept)
	defer interceptableForwarder.SetInterceptor(nil)

	// Start a goroutine that reads client resolutions. The reader reports
	// at most one error before it exits, so errChan gets a buffer of one:
	// if this loop has already returned for another reason, the reader's
	// send must still complete, otherwise the deferred onDisconnect would
	// wait on the reader forever.
	errChan := make(chan error, 1)
	resolutionRequests := make(chan *ForwardHtlcInterceptResponse)
	r.wg.Add(1)
	go r.readClientResponses(resolutionRequests, errChan)

	// Run the main loop that synchronizes both sides' input into one
	// goroutine.
	for {
		select {
		case intercepted := <-r.intercepted:
			log.Tracef("sending intercepted packet to client %v", intercepted)

			// In case we couldn't forward, we exit the loop and
			// drain the current interceptor, as this indicates a
			// connection problem.
			if err := r.holdAndForwardToClient(intercepted); err != nil {
				return err
			}

		case resolution := <-resolutionRequests:
			log.Tracef("resolving intercepted packet %v", resolution)

			// In case we couldn't resolve, we just add a log line
			// since this does not indicate a connection problem.
			if err := r.resolveFromClient(resolution); err != nil {
				log.Warnf("client resolution of intercepted "+
					"packet failed %v", err)
			}

		case err := <-errChan:
			return err

		case <-r.server.quit:
			return nil
		}
	}
}
// onIntercept is the callback registered with the switch; it is invoked for
// every forwarded packet. It hands the packet over to the main loop and
// reports true only when that hand-off succeeded. If the interceptor or the
// server shuts down before the packet is taken, it reports false.
func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool {
	select {
	case r.intercepted <- p:
		return true

	// Either shutdown signal means the packet was not delivered.
	case <-r.quit:
	case <-r.server.quit:
	}

	return false
}
// readClientResponses reads resolutions from the client stream and hands each
// one over on resolutionChan. It is started as a goroutine and marks itself as
// done on the interceptor's wait group when it exits.
//
// The loop ends when the stream returns an error, which is then offered on
// errChan, or when the interceptor's or the server's quit channel fires.
func (r *forwardInterceptor) readClientResponses(
	resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) {

	defer r.wg.Done()

	for {
		resp, err := r.stream.Recv()
		if err != nil {
			// Offer the error to the receiver, but never block on
			// it once a shutdown has started: by then nobody may
			// be receiving from errChan anymore, and a stuck send
			// would prevent this goroutine from ever releasing the
			// wait group.
			select {
			case errChan <- err:
			case <-r.quit:
			case <-r.server.quit:
			}

			return
		}

		// Now that we have the response from the RPC client, send it
		// to the responses chan.
		select {
		case resolutionChan <- resp:
		case <-r.quit:
			return
		case <-r.server.quit:
			return
		}
	}
}
// holdAndForwardToClient records the intercepted htlc as held, keyed by its
// incoming circuit, and then sends a description of it to the client over the
// stream. It returns whatever error the stream send reports.
func (r *forwardInterceptor) holdAndForwardToClient(
	forward htlcswitch.InterceptedForward) error {

	pkt := forward.Packet()
	circuit := pkt.IncomingCircuit

	// Hold the forward first, and only then notify the client.
	r.holdForwards[circuit] = forward

	incomingKey := &CircuitKey{
		ChanId: circuit.ChanID.ToUint64(),
		HtlcId: circuit.HtlcID,
	}

	return r.stream.Send(&ForwardHtlcInterceptRequest{
		IncomingCircuitKey:      incomingKey,
		OutgoingRequestedChanId: pkt.OutgoingChanID.ToUint64(),
		PaymentHash:             pkt.Hash[:],
		OutgoingAmountMsat:      uint64(pkt.OutgoingAmount),
		OutgoingExpiry:          pkt.OutgoingExpiry,
		IncomingAmountMsat:      uint64(pkt.IncomingAmount),
		IncomingExpiry:          pkt.IncomingExpiry,
		CustomRecords:           pkt.CustomRecords,
	})
}
// resolveFromClient applies a resolution that arrived from the client to the
// held forward it refers to.
//
// The request is fully validated before the forward is released from
// holdForwards. A request that carries no circuit key, an unrecognized action,
// or a missing or malformed preimage is rejected with an error and leaves the
// forward held, so that it can still be resolved later or resumed on
// disconnect. ErrFwdNotExists is returned when no forward is held for the
// given circuit key, and ErrMissingPreimage when a settle carries no preimage.
// Otherwise the result of the Resume, Fail or Settle call is returned.
func (r *forwardInterceptor) resolveFromClient(
	in *ForwardHtlcInterceptResponse) error {

	// The circuit key is supplied by the client and may be absent; reject
	// the request rather than dereferencing a nil pointer below.
	if in.IncomingCircuitKey == nil {
		return errors.New("missing incoming circuit key")
	}

	circuitKey := channeldb.CircuitKey{
		ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId),
		HtlcID: in.IncomingCircuitKey.HtlcId,
	}

	interceptedForward, ok := r.holdForwards[circuitKey]
	if !ok {
		return ErrFwdNotExists
	}

	// The forward is removed from the held set only once the request is
	// known to be valid, immediately before the resolution is attempted.
	switch in.Action {
	case ResolveHoldForwardAction_RESUME:
		delete(r.holdForwards, circuitKey)
		return interceptedForward.Resume()

	case ResolveHoldForwardAction_FAIL:
		delete(r.holdForwards, circuitKey)
		return interceptedForward.Fail()

	case ResolveHoldForwardAction_SETTLE:
		if in.Preimage == nil {
			return ErrMissingPreimage
		}

		preimage, err := lntypes.MakePreimage(in.Preimage)
		if err != nil {
			return err
		}

		delete(r.holdForwards, circuitKey)
		return interceptedForward.Settle(preimage)

	default:
		return fmt.Errorf("unrecognized resolve action %v", in.Action)
	}
}
// onDisconnect tears the session down. It closes the quit channel, resumes
// every forward that is still held (resuming being the default behavior),
// removes each of them from the store, and finally waits on the wait group.
func (r *forwardInterceptor) onDisconnect() {
	// Close the quit channel first so that all goroutines will exit.
	close(r.quit)

	log.Infof("RPC interceptor disconnected, resolving held packets")

	for circuitKey, held := range r.holdForwards {
		err := held.Resume()
		if err != nil {
			log.Errorf("failed to resume hold forward %v", err)
		}

		delete(r.holdForwards, circuitKey)
	}

	r.wg.Wait()
}