diff --git a/channeldb/control_tower.go b/channeldb/control_tower.go index 82e3c1ea..ae6c7e7c 100644 --- a/channeldb/control_tower.go +++ b/channeldb/control_tower.go @@ -57,23 +57,13 @@ type ControlTower interface { // paymentControl is persistent implementation of ControlTower to restrict // double payment sending. type paymentControl struct { - strict bool - db *DB } -// NewPaymentControl creates a new instance of the paymentControl. The strict -// flag indicates whether the controller should require "strict" state -// transitions, which would be otherwise intolerant to older databases that may -// already have duplicate payments to the same payment hash. It should be -// enabled only after sufficient checks have been made to ensure the db does not -// contain such payments. In the meantime, non-strict mode enforces a superset -// of the state transitions that prevent additional payments to a given payment -// hash from being added. -func NewPaymentControl(strict bool, db *DB) ControlTower { +// NewPaymentControl creates a new instance of the paymentControl. +func NewPaymentControl(db *DB) ControlTower { return &paymentControl{ - strict: strict, - db: db, + db: db, } } @@ -96,21 +86,20 @@ func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error { switch paymentStatus { + // It is safe to reattempt a payment if we know that we haven't + // left one in flight. Since this one is grounded or failed, + // transition the payment status to InFlight to prevent others. case StatusGrounded: - // It is safe to reattempt a payment if we know that we - // haven't left one in flight. Since this one is - // grounded or failed, transition the payment status - // to InFlight to prevent others. return bucket.Put(paymentStatusKey, StatusInFlight.Bytes()) + // We already have an InFlight payment on the network. We will + // disallow any more payment until a response is received. case StatusInFlight: - // We already have an InFlight payment on the network. We will - // disallow any more payment until a response is received. takeoffErr = ErrPaymentInFlight + // We've already completed a payment to this payment hash, + // forbid the switch from sending another. case StatusCompleted: - // We've already completed a payment to this payment hash, - // forbid the switch from sending another. takeoffErr = ErrAlreadyPaid default: @@ -146,27 +135,20 @@ func (p *paymentControl) Success(paymentHash [32]byte) error { switch { - case paymentStatus == StatusGrounded && p.strict: - // Our records show the payment as still being grounded, - // meaning it never should have left the switch. + // Our records show the payment as still being grounded, + // meaning it never should have left the switch. + case paymentStatus == StatusGrounded: updateErr = ErrPaymentNotInitiated - case paymentStatus == StatusGrounded && !p.strict: - // Though our records show the payment as still being - // grounded, meaning it never should have left the - // switch, we permit this transition in non-strict mode - // to handle inconsistent db states. - fallthrough - + // A successful response was received for an InFlight payment, + // mark it as completed to prevent sending to this payment hash + // again. case paymentStatus == StatusInFlight: - // A successful response was received for an InFlight - // payment, mark it as completed to prevent sending to - // this payment hash again. return bucket.Put(paymentStatusKey, StatusCompleted.Bytes()) + // The payment was completed previously, alert the caller that + // this may be a duplicate call. case paymentStatus == StatusCompleted: - // The payment was completed previously, alert the - // caller that this may be a duplicate call. updateErr = ErrPaymentAlreadyCompleted default: @@ -201,29 +183,20 @@ func (p *paymentControl) Fail(paymentHash [32]byte) error { switch { - case paymentStatus == StatusGrounded && p.strict: - // Our records show the payment as still being grounded, - // meaning it never should have left the switch. + // Our records show the payment as still being grounded, + // meaning it never should have left the switch. + case paymentStatus == StatusGrounded: updateErr = ErrPaymentNotInitiated - case paymentStatus == StatusGrounded && !p.strict: - // Though our records show the payment as still being - // grounded, meaning it never should have left the - // switch, we permit this transition in non-strict mode - // to handle inconsistent db states. - fallthrough - + // A failed response was received for an InFlight payment, mark + // it as Failed to allow subsequent attempts. case paymentStatus == StatusInFlight: - // A failed response was received for an InFlight - // payment, mark it as Failed to allow subsequent - // attempts. return bucket.Put(paymentStatusKey, StatusGrounded.Bytes()) + // The payment was completed previously, and we are now + // reporting that it has failed. Leave the status as completed, + // but alert the user that something is wrong. case paymentStatus == StatusCompleted: - // The payment was completed previously, and we are now - // reporting that it has failed. Leave the status as - // completed, but alert the user that something is - // wrong. updateErr = ErrPaymentAlreadyCompleted default: diff --git a/channeldb/control_tower_test.go b/channeldb/control_tower_test.go index 42df8bf4..a70e3d11 100644 --- a/channeldb/control_tower_test.go +++ b/channeldb/control_tower_test.go @@ -49,41 +49,22 @@ func genHtlc() (*lnwire.UpdateAddHTLC, error) { return htlc, nil } -type paymentControlTestCase func(*testing.T, bool) +type paymentControlTestCase func(*testing.T) var paymentControlTests = []struct { name string - strict bool testcase paymentControlTestCase }{ { - name: "fail-strict", - strict: true, + name: "fail", testcase: testPaymentControlSwitchFail, }, { - name: "double-send-strict", - strict: true, + name: "double-send", testcase: testPaymentControlSwitchDoubleSend, }, { - name: "double-pay-strict", - strict: true, - testcase: testPaymentControlSwitchDoublePay, - }, - { - name: "fail-not-strict", - strict: false, - testcase: testPaymentControlSwitchFail, - }, - { - name: "double-send-not-strict", - strict: false, - testcase: testPaymentControlSwitchDoubleSend, - }, - { - name: "double-pay-not-strict", - strict: false, + name: "double-pay", testcase: testPaymentControlSwitchDoublePay, }, } @@ -96,7 +77,7 @@ var paymentControlTests = []struct { func TestPaymentControls(t *testing.T) { for _, test := range paymentControlTests { t.Run(test.name, func(t *testing.T) { - test.testcase(t, test.strict) + test.testcase(t) }) } } @@ -104,7 +85,7 @@ func TestPaymentControls(t *testing.T) { // testPaymentControlSwitchFail checks that payment status returns to Grounded // status after failing, and that ClearForTakeoff allows another HTLC for the // same payment hash. -func testPaymentControlSwitchFail(t *testing.T, strict bool) { +func testPaymentControlSwitchFail(t *testing.T) { t.Parallel() db, err := initDB() @@ -112,7 +93,7 @@ func testPaymentControlSwitchFail(t *testing.T, strict bool) { t.Fatalf("unable to init db: %v", err) } - pControl := NewPaymentControl(strict, db) + pControl := NewPaymentControl(db) htlc, err := genHtlc() if err != nil { @@ -158,7 +139,7 @@ func testPaymentControlSwitchFail(t *testing.T, strict bool) { // testPaymentControlSwitchDoubleSend checks the ability of payment control to // prevent double sending of htlc message, when message is in StatusInFlight. -func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) { +func testPaymentControlSwitchDoubleSend(t *testing.T) { t.Parallel() db, err := initDB() @@ -166,7 +147,7 @@ func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) { t.Fatalf("unable to init db: %v", err) } - pControl := NewPaymentControl(strict, db) + pControl := NewPaymentControl(db) htlc, err := genHtlc() if err != nil { @@ -192,7 +173,7 @@ func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) { // TestPaymentControlSwitchDoublePay checks the ability of payment control to // prevent double payment. -func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) { +func testPaymentControlSwitchDoublePay(t *testing.T) { t.Parallel() db, err := initDB() @@ -200,7 +181,7 @@ func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) { t.Fatalf("unable to init db: %v", err) } - pControl := NewPaymentControl(strict, db) + pControl := NewPaymentControl(db) htlc, err := genHtlc() if err != nil { @@ -229,86 +210,6 @@ func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) { } } -// TestPaymentControlNonStrictSuccessesWithoutInFlight checks that a non-strict -// payment control will allow calls to Success when no payment is in flight. This -// is necessary to gracefully handle the case in which the switch already sent -// out a payment for a particular payment hash in a prior db version that didn't -// have payment statuses. -func TestPaymentControlNonStrictSuccessesWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(false, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - if err := pControl.Success(htlc.PaymentHash); err != nil { - t.Fatalf("unable to mark payment hash success: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted) - - err = pControl.Success(htlc.PaymentHash) - if err != ErrPaymentAlreadyCompleted { - t.Fatalf("unable to remark payment hash failed: %v", err) - } -} - -// TestPaymentControlNonStrictFailsWithoutInFlight checks that a non-strict -// payment control will allow calls to Fail when no payment is in flight. This -// is necessary to gracefully handle the case in which the switch already sent -// out a payment for a particular payment hash in a prior db version that didn't -// have payment statuses. -func TestPaymentControlNonStrictFailsWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(false, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - if err := pControl.Fail(htlc.PaymentHash); err != nil { - t.Fatalf("unable to mark payment hash failed: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded) - - err = pControl.Fail(htlc.PaymentHash) - if err != nil { - t.Fatalf("unable to remark payment hash failed: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded) - - err = pControl.Success(htlc.PaymentHash) - if err != nil { - t.Fatalf("unable to remark payment hash success: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted) - - err = pControl.Fail(htlc.PaymentHash) - if err != ErrPaymentAlreadyCompleted { - t.Fatalf("unable to remark payment hash failed: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted) -} - // TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment // control will disallow calls to Success when no payment is in flight. func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) { @@ -319,7 +220,7 @@ func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) { t.Fatalf("unable to init db: %v", err) } - pControl := NewPaymentControl(true, db) + pControl := NewPaymentControl(db) htlc, err := genHtlc() if err != nil { @@ -344,7 +245,7 @@ func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) { t.Fatalf("unable to init db: %v", err) } - pControl := NewPaymentControl(true, db) + pControl := NewPaymentControl(db) htlc, err := genHtlc() if err != nil {