diff --git a/channeldb/control_tower_test.go b/channeldb/control_tower_test.go index 2d6e71d6..612a058e 100644 --- a/channeldb/control_tower_test.go +++ b/channeldb/control_tower_test.go @@ -1,15 +1,18 @@ package channeldb import ( + "bytes" "crypto/rand" "fmt" "io" "io/ioutil" + "reflect" "testing" "time" "github.com/btcsuite/fastsha256" "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" ) @@ -58,43 +61,10 @@ func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo, }, preimage, nil } -type paymentControlTestCase func(*testing.T) - -var paymentControlTests = []struct { - name string - testcase paymentControlTestCase -}{ - { - name: "fail", - testcase: testPaymentControlSwitchFail, - }, - { - name: "double-send", - testcase: testPaymentControlSwitchDoubleSend, - }, - { - name: "double-pay", - testcase: testPaymentControlSwitchDoublePay, - }, -} - -// TestPaymentControls runs a set of common tests against both the strict and -// non-strict payment control instances. This ensures that the two both behave -// identically when making the expected state-transitions of the stricter -// implementation. Behavioral differences in the strict and non-strict -// implementations are tested separately. -func TestPaymentControls(t *testing.T) { - for _, test := range paymentControlTests { - t.Run(test.name, func(t *testing.T) { - test.testcase(t) - }) - } -} - -// testPaymentControlSwitchFail checks that payment status returns to Grounded -// status after failing, and that ClearForTakeoff allows another HTLC for the +// TestPaymentControlSwitchFail checks that payment status returns to Failed +// status after failing, and that InitPayment allows another HTLC for the // same payment hash. -func testPaymentControlSwitchFail(t *testing.T) { +func TestPaymentControlSwitchFail(t *testing.T) { t.Parallel() db, err := initDB() @@ -104,7 +74,7 @@ func testPaymentControlSwitchFail(t *testing.T) { pControl := NewPaymentControl(db) - info, _, preimg, err := genInfo() + info, attempt, preimg, err := genInfo() if err != nil { t.Fatalf("unable to generate htlc message: %v", err) } @@ -116,14 +86,20 @@ func testPaymentControlSwitchFail(t *testing.T) { } assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + ) - // Fail the payment, which should moved it to Grounded. + // Fail the payment, which should moved it to Failed. if err := pControl.Fail(info.PaymentHash); err != nil { t.Fatalf("unable to fail payment hash: %v", err) } // Verify the status is indeed Failed. assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + ) // Sends the htlc again, which should succeed since the prior payment // failed. @@ -133,6 +109,20 @@ func testPaymentControlSwitchFail(t *testing.T) { } assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + ) + + // Record a new attempt. + attempt.PaymentID = 2 + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + ) // Verifies that status was changed to StatusCompleted. if err := pControl.Success(info.PaymentHash, preimg); err != nil { @@ -140,6 +130,7 @@ func testPaymentControlSwitchFail(t *testing.T) { } assertPaymentStatus(t, db, info.PaymentHash, StatusCompleted) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg) // Attempt a final payment, which should now fail since the prior // payment succeed. @@ -149,9 +140,9 @@ func testPaymentControlSwitchFail(t *testing.T) { } } -// testPaymentControlSwitchDoubleSend checks the ability of payment control to +// TestPaymentControlSwitchDoubleSend checks the ability of payment control to // prevent double sending of htlc message, when message is in StatusInFlight. -func testPaymentControlSwitchDoubleSend(t *testing.T) { +func TestPaymentControlSwitchDoubleSend(t *testing.T) { t.Parallel() db, err := initDB() @@ -161,7 +152,7 @@ func testPaymentControlSwitchDoubleSend(t *testing.T) { pControl := NewPaymentControl(db) - info, _, _, err := genInfo() + info, attempt, preimg, err := genInfo() if err != nil { t.Fatalf("unable to generate htlc message: %v", err) } @@ -174,6 +165,9 @@ func testPaymentControlSwitchDoubleSend(t *testing.T) { } assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + ) // Try to initiate double sending of htlc message with the same // payment hash, should result in error indicating that payment has @@ -183,53 +177,40 @@ func testPaymentControlSwitchDoubleSend(t *testing.T) { t.Fatalf("payment control wrong behaviour: " + "double sending must trigger ErrPaymentInFlight error") } -} -// TestPaymentControlSwitchDoublePay checks the ability of payment control to -// prevent double payment. -func testPaymentControlSwitchDoublePay(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, _, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) + // Record an attempt. + err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } - - // Verify that payment is InFlight. assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + ) - // Move payment to completed status, second payment should return error. - err = pControl.Success(info.PaymentHash, preimg) - if err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") } - // Verify that payment is Completed. + // After settling, the error should be ErrAlreadyPaid. + if err := pControl.Success(info.PaymentHash, preimg); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } assertPaymentStatus(t, db, info.PaymentHash, StatusCompleted) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg) err = pControl.InitPayment(info.PaymentHash, info) if err != ErrAlreadyPaid { - t.Fatalf("payment control wrong behaviour:" + - " double payment must trigger ErrAlreadyPaid") + t.Fatalf("unable to send htlc message: %v", err) } } -// TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment +// TestPaymentControlSuccessesWithoutInFlight checks that the payment // control will disallow calls to Success when no payment is in flight. -func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) { +func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { t.Parallel() db, err := initDB() @@ -244,17 +225,19 @@ func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) { t.Fatalf("unable to generate htlc message: %v", err) } + // Attempt to complete the payment should fail. err = pControl.Success(info.PaymentHash, preimg) if err != ErrPaymentNotInitiated { t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) } assertPaymentStatus(t, db, info.PaymentHash, StatusGrounded) + assertPaymentInfo(t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}) } -// TestPaymentControlStrictFailsWithoutInFlight checks that a strict payment +// TestPaymentControlFailsWithoutInFlight checks that a strict payment // control will disallow calls to Fail when no payment is in flight. -func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) { +func TestPaymentControlFailsWithoutInFlight(t *testing.T) { t.Parallel() db, err := initDB() @@ -269,12 +252,14 @@ func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) { t.Fatalf("unable to generate htlc message: %v", err) } + // Calling Fail should return an error. err = pControl.Fail(info.PaymentHash) if err != ErrPaymentNotInitiated { t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) } assertPaymentStatus(t, db, info.PaymentHash, StatusGrounded) + assertPaymentInfo(t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}) } func assertPaymentStatus(t *testing.T, db *DB, @@ -307,3 +292,116 @@ func assertPaymentStatus(t *testing.T, db *DB, expStatus, paymentStatus) } } + +func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { + b := bucket.Get(paymentCreationInfoKey) + switch { + case b == nil && c == nil: + return nil + case b == nil: + return fmt.Errorf("expected creation info not found") + case c == nil: + return fmt.Errorf("unexpected creation info found") + } + + r := bytes.NewReader(b) + c2, err := deserializePaymentCreationInfo(r) + if err != nil { + fmt.Println("creation info err: ", err) + return err + } + if !reflect.DeepEqual(c, c2) { + return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", + spew.Sdump(c), spew.Sdump(c2)) + } + + return nil +} + +func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error { + b := bucket.Get(paymentAttemptInfoKey) + switch { + case b == nil && a == nil: + return nil + case b == nil: + return fmt.Errorf("expected attempt info not found") + case a == nil: + return fmt.Errorf("unexpected attempt info found") + } + + r := bytes.NewReader(b) + a2, err := deserializePaymentAttemptInfo(r) + if err != nil { + return err + } + if !reflect.DeepEqual(a, a2) { + return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", + spew.Sdump(a), spew.Sdump(a2)) + } + + return nil +} + +func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { + zero := lntypes.Preimage{} + b := bucket.Get(paymentSettleInfoKey) + switch { + case b == nil && preimg == zero: + return nil + case b == nil: + return fmt.Errorf("expected preimage not found") + case preimg == zero: + return fmt.Errorf("unexpected preimage found") + } + + var pre2 lntypes.Preimage + copy(pre2[:], b[:]) + if preimg != pre2 { + return fmt.Errorf("Preimages don't match: %x vs %x", + preimg, pre2) + } + + return nil +} + +func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, + c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage) { + + t.Helper() + + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil && c == nil { + return nil + } + if payments == nil { + return fmt.Errorf("sent payments not found") + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil && c == nil { + return nil + } + + if bucket == nil { + return fmt.Errorf("payment not found") + } + + if err := checkPaymentCreationInfo(bucket, c); err != nil { + return err + } + + if err := checkPaymentAttemptInfo(bucket, a); err != nil { + return err + } + + if err := checkSettleInfo(bucket, s); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatalf("assert payment info failed: %v", err) + } + +}