channeldb/test: refactor payment control test

Previously this was tested as a white box. Database access methods were
duplicated as test code and compared to the return value of the code
under test. This approaches leads to brittle test because it relies
heavily on implementation details. This commit changes this and prepares
for additional test coverage being added in later commits.
This commit is contained in:
Joost Jager 2020-02-20 14:43:25 +01:00
parent bee2380441
commit c29b74168f
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7

@ -1,7 +1,6 @@
package channeldb
import (
"bytes"
"crypto/rand"
"fmt"
"io"
@ -11,7 +10,6 @@ import (
"time"
"github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lntypes"
)
@ -85,9 +83,9 @@ func TestPaymentControlSwitchFail(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{},
nil,
)
@ -99,9 +97,9 @@ func TestPaymentControlSwitchFail(t *testing.T) {
}
// Verify the status is indeed Failed.
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusFailed)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{},
&failReason,
)
@ -112,9 +110,9 @@ func TestPaymentControlSwitchFail(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{},
nil,
)
@ -124,9 +122,9 @@ func TestPaymentControlSwitchFail(t *testing.T) {
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
t, pControl, info.PaymentHash, info, attempt, lntypes.Preimage{},
nil,
)
@ -149,8 +147,8 @@ func TestPaymentControlSwitchFail(t *testing.T) {
spew.Sdump(payment.HTLCs[0].Route), err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(t, pControl, info.PaymentHash, info, attempt, preimg, nil)
// Attempt a final payment, which should now fail since the prior
// payment succeed.
@ -184,9 +182,9 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
t, pControl, info.PaymentHash, info, nil, lntypes.Preimage{},
nil,
)
@ -204,9 +202,9 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
if err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
t, pControl, info.PaymentHash, info, attempt, lntypes.Preimage{},
nil,
)
@ -221,8 +219,8 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) {
if _, err := pControl.Success(info.PaymentHash, preimg); err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(t, pControl, info.PaymentHash, info, attempt, preimg, nil)
err = pControl.InitPayment(info.PaymentHash, info)
if err != ErrAlreadyPaid {
@ -253,11 +251,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
assertPaymentInfo(
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{},
nil,
)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown)
}
// TestPaymentControlFailsWithoutInFlight checks that a strict payment
@ -283,10 +277,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
assertPaymentInfo(
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil,
)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusUnknown)
}
// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only
@ -344,9 +335,9 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
}
// Verify the status is indeed Failed.
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusFailed)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt,
t, pControl, info.PaymentHash, info, attempt,
lntypes.Preimage{}, &failReason,
)
} else if p.success {
@ -356,14 +347,14 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusSucceeded)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt, preimg, nil,
t, pControl, info.PaymentHash, info, attempt, preimg, nil,
)
} else {
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
assertPaymentStatus(t, pControl, info.PaymentHash, StatusInFlight)
assertPaymentInfo(
t, db, info.PaymentHash, info, attempt,
t, pControl, info.PaymentHash, info, attempt,
lntypes.Preimage{}, nil,
)
}
@ -390,166 +381,76 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) {
}
}
func assertPaymentStatus(t *testing.T, db *DB,
hash [32]byte, expStatus PaymentStatus) {
// assertPaymentStatus retrieves the status of the payment referred to by hash
// and compares it with the expected state.
func assertPaymentStatus(t *testing.T, p *PaymentControl,
hash lntypes.Hash, expStatus PaymentStatus) {
t.Helper()
var paymentStatus = StatusUnknown
err := db.View(func(tx *bbolt.Tx) error {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil {
return nil
payment, err := p.FetchPayment(hash)
if expStatus == StatusUnknown && err == ErrPaymentNotInitiated {
return
}
bucket := payments.Bucket(hash[:])
if bucket == nil {
return nil
}
// Get the existing status of this payment, if any.
paymentStatus = fetchPaymentStatus(bucket)
return nil
})
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
t.Fatal(err)
}
if paymentStatus != expStatus {
if payment.Status != expStatus {
t.Fatalf("payment status mismatch: expected %v, got %v",
expStatus, paymentStatus)
expStatus, payment.Status)
}
}
func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error {
b := bucket.Get(paymentCreationInfoKey)
switch {
case b == nil && c == nil:
return nil
case b == nil:
return fmt.Errorf("expected creation info not found")
case c == nil:
return fmt.Errorf("unexpected creation info found")
}
r := bytes.NewReader(b)
c2, err := deserializePaymentCreationInfo(r)
if err != nil {
return err
}
if !reflect.DeepEqual(c, c2) {
return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v",
spew.Sdump(c), spew.Sdump(c2))
}
return nil
}
func checkHTLCAttemptInfo(bucket *bbolt.Bucket, a *HTLCAttemptInfo) error {
b := bucket.Get(paymentAttemptInfoKey)
switch {
case b == nil && a == nil:
return nil
case b == nil:
return fmt.Errorf("expected attempt info not found")
case a == nil:
return fmt.Errorf("unexpected attempt info found")
}
r := bytes.NewReader(b)
a2, err := deserializeHTLCAttemptInfo(r)
if err != nil {
return err
}
return assertRouteEqual(&a.Route, &a2.Route)
}
func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error {
zero := lntypes.Preimage{}
b := bucket.Get(paymentSettleInfoKey)
switch {
case b == nil && preimg == zero:
return nil
case b == nil:
return fmt.Errorf("expected preimage not found")
case preimg == zero:
return fmt.Errorf("unexpected preimage found")
}
var pre2 lntypes.Preimage
copy(pre2[:], b[:])
if preimg != pre2 {
return fmt.Errorf("Preimages don't match: %x vs %x",
preimg, pre2)
}
return nil
}
func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error {
b := bucket.Get(paymentFailInfoKey)
switch {
case b == nil && failReason == nil:
return nil
case b == nil:
return fmt.Errorf("expected fail info not found")
case failReason == nil:
return fmt.Errorf("unexpected fail info found")
}
failReason2 := FailureReason(b[0])
if *failReason != failReason2 {
return fmt.Errorf("Failure infos don't match: %v vs %v",
*failReason, failReason2)
}
return nil
}
func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash,
// assertPaymentInfo retrieves the payment referred to by hash and verifies the
// expected values.
func assertPaymentInfo(t *testing.T, p *PaymentControl, hash lntypes.Hash,
c *PaymentCreationInfo, a *HTLCAttemptInfo, s lntypes.Preimage,
f *FailureReason) {
t.Helper()
err := db.View(func(tx *bbolt.Tx) error {
payments := tx.Bucket(paymentsRootBucket)
if payments == nil && c == nil {
return nil
}
if payments == nil {
return fmt.Errorf("sent payments not found")
}
bucket := payments.Bucket(hash[:])
if bucket == nil && c == nil {
return nil
}
if bucket == nil {
return fmt.Errorf("payment not found")
}
if err := checkPaymentCreationInfo(bucket, c); err != nil {
return err
}
if err := checkHTLCAttemptInfo(bucket, a); err != nil {
return err
}
if err := checkSettleInfo(bucket, s); err != nil {
return err
}
if err := checkFailInfo(bucket, f); err != nil {
return err
}
return nil
})
payment, err := p.FetchPayment(hash)
if err != nil {
t.Fatalf("assert payment info failed: %v", err)
t.Fatal(err)
}
if !reflect.DeepEqual(payment.Info, c) {
t.Fatalf("PaymentCreationInfos don't match: %v vs %v",
spew.Sdump(payment.Info), spew.Sdump(c))
}
if f != nil {
if *payment.FailureReason != *f {
t.Fatal("unexpected failure reason")
}
} else {
if payment.FailureReason != nil {
t.Fatal("unexpected failure reason")
}
}
if a == nil {
if len(payment.HTLCs) > 0 {
t.Fatal("expected no htlcs")
}
return
}
htlc := payment.HTLCs[0]
if err := assertRouteEqual(&htlc.Route, &a.Route); err != nil {
t.Fatal("routes do not match")
}
var zeroPreimage = lntypes.Preimage{}
if s != zeroPreimage {
if htlc.Settle.Preimage != s {
t.Fatalf("Preimages don't match: %x vs %x",
htlc.Settle.Preimage, s)
}
} else {
if htlc.Settle != nil {
t.Fatal("expected no settle info")
}
}
}