package routerrpc

import (
	"errors"
	"fmt"
	"sync"

	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/htlcswitch"
	"github.com/lightningnetwork/lnd/lntypes"
	"github.com/lightningnetwork/lnd/lnwire"
)

var (
	// ErrFwdNotExists is returned when the caller attempts to resolve a
	// forward that is no longer being held.
	ErrFwdNotExists = errors.New("forward does not exist")

	// ErrMissingPreimage is returned when the caller attempts to settle a
	// forward without supplying a preimage.
	ErrMissingPreimage = errors.New("missing preimage")
)

// forwardInterceptor is a helper struct that manages the lifecycle of a single
// rpc interceptor streaming session.
// It is created when the stream opens and disconnects when the stream closes.
type forwardInterceptor struct {
	// server is a reference to the owning Server.
	server *Server

	// holdForwards maps the incoming circuit key of every forward that is
	// currently on hold to its htlcswitch.InterceptedForward.
	holdForwards map[channeldb.CircuitKey]htlcswitch.InterceptedForward

	// stream is the bidirectional RPC stream to the client.
	stream Router_HtlcInterceptorServer

	// quit is closed when this forwardInterceptor is shutting down.
	quit chan struct{}

	// intercepted carries every intercepted packet coming from the switch
	// into the main loop.
	intercepted chan htlcswitch.InterceptedForward

	// wg is used to wait for the goroutines started by this interceptor
	// to exit.
	wg sync.WaitGroup
}

// newForwardInterceptor creates a new forwardInterceptor bound to the given
// server and client stream, with an empty set of held forwards.
func newForwardInterceptor(server *Server,
	stream Router_HtlcInterceptorServer) *forwardInterceptor {

	held := make(map[channeldb.CircuitKey]htlcswitch.InterceptedForward)

	interceptor := &forwardInterceptor{
		server:       server,
		stream:       stream,
		holdForwards: held,
		quit:         make(chan struct{}),
		intercepted:  make(chan htlcswitch.InterceptedForward),
	}

	return interceptor
}

// run sends the intercepted packets to the client and receives the
// corresponding responses. On one hand it registers itself as an interceptor
// that receives the switch packets and on the other hand launches a go routine
// to read from the client stream.
// To coordinate all this and make sure it is safe for concurrent access all // packets are sent to the main where they are handled. func (r *forwardInterceptor) run() error { // make sure we disconnect and resolves all remaining packets if any. defer r.onDisconnect() // Register our interceptor so we receive all forwarded packets. interceptableForwarder := r.server.cfg.RouterBackend.InterceptableForwarder interceptableForwarder.SetInterceptor(r.onIntercept) defer interceptableForwarder.SetInterceptor(nil) // start a go routine that reads client resolutions. errChan := make(chan error) resolutionRequests := make(chan *ForwardHtlcInterceptResponse) r.wg.Add(1) go r.readClientResponses(resolutionRequests, errChan) // run the main loop that synchronizes both sides input into one go routine. for { select { case intercepted := <-r.intercepted: log.Tracef("sending intercepted packet to client %v", intercepted) // in case we couldn't forward we exit the loop and drain the // current interceptor as this indicates on a connection problem. if err := r.holdAndForwardToClient(intercepted); err != nil { return err } case resolution := <-resolutionRequests: log.Tracef("resolving intercepted packet %v", resolution) // in case we couldn't resolve we just add a log line since this // does not indicate on any connection problem. if err := r.resolveFromClient(resolution); err != nil { log.Warnf("client resolution of intercepted "+ "packet failed %v", err) } case err := <-errChan: return err case <-r.server.quit: return nil } } } // onIntercept is the function that is called by the switch for every forwarded // packet. Our interceptor makes sure we hold the packet and then signal to the // main loop to handle the packet. We only return true if we were able // to deliver the packet to the main loop. 
func (r *forwardInterceptor) onIntercept(p htlcswitch.InterceptedForward) bool { select { case r.intercepted <- p: return true case <-r.quit: return false case <-r.server.quit: return false } } func (r *forwardInterceptor) readClientResponses( resolutionChan chan *ForwardHtlcInterceptResponse, errChan chan error) { defer r.wg.Done() for { resp, err := r.stream.Recv() if err != nil { errChan <- err return } // Now that we have the response from the RPC client, send it to // the responses chan. select { case resolutionChan <- resp: case <-r.quit: return case <-r.server.quit: return } } } // holdAndForwardToClient forwards the intercepted htlc to the client. func (r *forwardInterceptor) holdAndForwardToClient( forward htlcswitch.InterceptedForward) error { htlc := forward.Packet() inKey := htlc.IncomingCircuit // First hold the forward, then send to client. r.holdForwards[inKey] = forward interceptionRequest := &ForwardHtlcInterceptRequest{ IncomingCircuitKey: &CircuitKey{ ChanId: inKey.ChanID.ToUint64(), HtlcId: inKey.HtlcID, }, OutgoingRequestedChanId: htlc.OutgoingChanID.ToUint64(), PaymentHash: htlc.Hash[:], OutgoingAmountMsat: uint64(htlc.OutgoingAmount), OutgoingExpiry: htlc.OutgoingExpiry, IncomingAmountMsat: uint64(htlc.IncomingAmount), IncomingExpiry: htlc.IncomingExpiry, CustomRecords: htlc.CustomRecords, } return r.stream.Send(interceptionRequest) } // resolveFromClient handles a resolution arrived from the client. 
func (r *forwardInterceptor) resolveFromClient( in *ForwardHtlcInterceptResponse) error { circuitKey := channeldb.CircuitKey{ ChanID: lnwire.NewShortChanIDFromInt(in.IncomingCircuitKey.ChanId), HtlcID: in.IncomingCircuitKey.HtlcId, } var interceptedForward htlcswitch.InterceptedForward interceptedForward, ok := r.holdForwards[circuitKey] if !ok { return ErrFwdNotExists } delete(r.holdForwards, circuitKey) switch in.Action { case ResolveHoldForwardAction_RESUME: return interceptedForward.Resume() case ResolveHoldForwardAction_FAIL: return interceptedForward.Fail() case ResolveHoldForwardAction_SETTLE: if in.Preimage == nil { return ErrMissingPreimage } preimage, err := lntypes.MakePreimage(in.Preimage) if err != nil { return err } return interceptedForward.Settle(preimage) default: return fmt.Errorf("unrecognized resolve action %v", in.Action) } } // onDisconnect removes all previousely held forwards from // the store. Before they are removed it ensure to resume as the default // behavior. func (r *forwardInterceptor) onDisconnect() { // Then close the channel so all go routine will exit. close(r.quit) log.Infof("RPC interceptor disconnected, resolving held packets") for key, forward := range r.holdForwards { if err := forward.Resume(); err != nil { log.Errorf("failed to resume hold forward %v", err) } delete(r.holdForwards, key) } r.wg.Wait() }