diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 8f94a9a0..a13172bc 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -9,12 +9,11 @@ import ( "sync/atomic" "time" + "github.com/boltdb/bolt" "github.com/btcsuite/btcd/btcec" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" - "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" @@ -204,6 +203,9 @@ type Switch struct { paymentSequencer Sequencer + // control provides verification of sending htlc mesages + control ControlTower + // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. circuits CircuitMap @@ -289,6 +291,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { cfg: &cfg, circuits: circuitMap, paymentSequencer: sequencer, + control: NewPaymentControl(cfg.DB), linkIndex: make(map[lnwire.ChannelID]ChannelLink), mailOrchestrator: newMailOrchestrator(), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), @@ -344,6 +347,11 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, htlc *lnwire.UpdateAddHTLC, deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) { + // Verify message by ControlTower implementation. + if err := s.control.CheckSend(htlc); err != nil { + return zeroPreimage, err + } + // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ @@ -376,6 +384,10 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, if err := s.forward(packet); err != nil { s.removePendingPayment(paymentID) + if err := s.control.Fail(htlc.PaymentHash); err != nil { + return zeroPreimage, err + } + return zeroPreimage, err } @@ -837,6 +849,10 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { payment.preimage <- htlc.PaymentPreimage s.removePendingPayment(pkt.incomingHTLCID) + if err := s.control.Success(pkt.circuit.PaymentHash); err != nil { + return err + } + // We've just received a fail update which means we can finalize the // user payment and return fail response. case *lnwire.UpdateFailHTLC: @@ -901,6 +917,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, FailureMessage: lnwire.FailPermanentChannelFailure{}, } + if err := s.control.Fail(pkt.circuit.PaymentHash); err != nil { + log.Error(err) + } + // A regular multi-hop payment error that we'll need to // decrypt. default: @@ -917,6 +937,10 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, ExtraMsg: userErr, FailureMessage: lnwire.NewTemporaryChannelFailure(nil), } + + if err := s.control.Fail(pkt.circuit.PaymentHash); err != nil { + log.Error(err) + } } }