diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index 49d60ad0..78dec019 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -238,6 +238,14 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, if in.FinalCltvDelta != 0 { finalCLTVDelta = uint16(in.FinalCltvDelta) } + + // Do bounds checking without block padding so we don't give routes + // that will leave the router in a zombie payment state. + err = routing.ValidateCLTVLimit(cltvLimit, finalCLTVDelta, false) + if err != nil { + return nil, err + } + cltvLimit -= uint32(finalCLTVDelta) // Parse destination feature bits. @@ -860,6 +868,15 @@ func (r *RouterBackend) extractIntentFromSendRequest( payIntent.DestFeatures = features } + // Do bounds checking with the block padding so the router isn't + // left with a zombie payment in case the user messes up. + err = routing.ValidateCLTVLimit( + payIntent.CltvLimit, payIntent.FinalCLTVDelta, true, + ) + if err != nil { + return nil, err + } + // Check for disallowed payments to self. if !rpcPayReq.AllowSelfPayment && payIntent.Target == r.SelfNode { return nil, errors.New("self-payments not allowed") diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index f92a6f15..26a44cbb 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -37,17 +37,22 @@ var ( // and passed onto path finding. func TestQueryRoutes(t *testing.T) { t.Run("no mission control", func(t *testing.T) { - testQueryRoutes(t, false, false) + testQueryRoutes(t, false, false, true) }) t.Run("no mission control and msat", func(t *testing.T) { - testQueryRoutes(t, false, true) + testQueryRoutes(t, false, true, true) }) t.Run("with mission control", func(t *testing.T) { - testQueryRoutes(t, true, false) + testQueryRoutes(t, true, false, true) + }) + t.Run("no mission control bad cltv limit", func(t *testing.T) { + testQueryRoutes(t, false, false, false) }) } -func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool) { +func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool, + setTimelock bool) { + ignoreNodeBytes, err := hex.DecodeString(ignoreNodeKey) if err != nil { t.Fatal(err) @@ -207,7 +212,21 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool) { }, } + // If this is set, we'll populate MaxTotalTimelock. If this is not set, + // the test will fail as CltvLimit will be 0. + if setTimelock { + backend.MaxTotalTimelock = 1000 + } + resp, err := backend.QueryRoutes(context.Background(), request) + + // If no MaxTotalTimelock was set for the QueryRoutes request, make + // sure an error was returned. + if !setTimelock { + require.NotEmpty(t, err) + return + } + if err != nil { t.Fatal(err) } diff --git a/routing/payment_session.go b/routing/payment_session.go index ebaef74d..9dc280fa 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -14,6 +14,22 @@ import ( // to prevent an HTLC being failed if some blocks are mined while it's in-flight. const BlockPadding uint16 = 3 +// ValidateCLTVLimit is a helper function that validates that the cltv limit is +// greater than the final cltv delta parameter, optionally including the +// BlockPadding in this calculation. +func ValidateCLTVLimit(limit uint32, delta uint16, includePad bool) error { + if includePad { + delta += BlockPadding + } + + if limit <= uint32(delta) { + return fmt.Errorf("cltv limit %v should be greater than %v", + limit, delta) + } + + return nil +} + // noRouteError encodes a non-critical error encountered during path finding. type noRouteError uint8 diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 2d1eb999..2fa103d3 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -6,8 +6,70 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" + "github.com/stretchr/testify/require" ) +func TestValidateCLTVLimit(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + cltvLimit uint32 + finalCltvDelta uint16 + includePadding bool + expectError bool + }{ + { + name: "bad limit with padding", + cltvLimit: uint32(103), + finalCltvDelta: uint16(100), + includePadding: true, + expectError: true, + }, + { + name: "good limit with padding", + cltvLimit: uint32(104), + finalCltvDelta: uint16(100), + includePadding: true, + expectError: false, + }, + { + name: "bad limit no padding", + cltvLimit: uint32(100), + finalCltvDelta: uint16(100), + includePadding: false, + expectError: true, + }, + { + name: "good limit no padding", + cltvLimit: uint32(101), + finalCltvDelta: uint16(100), + includePadding: false, + expectError: false, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + success := t.Run(testCase.name, func(t *testing.T) { + err := ValidateCLTVLimit( + testCase.cltvLimit, testCase.finalCltvDelta, + testCase.includePadding, + ) + + if testCase.expectError { + require.NotEmpty(t, err) + } else { + require.NoError(t, err) + } + }) + if !success { + break + } + } +} + func TestRequestRoute(t *testing.T) { const ( height = 10 diff --git a/rpcserver.go b/rpcserver.go index 6338bb7b..c425c927 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -4374,6 +4374,15 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme payIntent.cltvDelta = uint16(r.cfg.Bitcoin.TimeLockDelta) } + // Do bounds checking with the block padding so the router isn't left + // with a zombie payment in case the user messes up. + err = routing.ValidateCLTVLimit( + payIntent.cltvLimit, payIntent.cltvDelta, true, + ) + if err != nil { + return payIntent, err + } + // If the user is manually specifying payment details, then the payment // hash may be encoded as a string. switch {