diff --git a/channeldb/codec.go b/channeldb/codec.go index ca5cfeed..e9afe0e1 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -148,6 +148,12 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case *btcec.PrivateKey: + b := e.Serialize() + if _, err := w.Write(b); err != nil { + return err + } + case *btcec.PublicKey: b := e.SerializeCompressed() if _, err := w.Write(b); err != nil { @@ -320,6 +326,15 @@ func ReadElement(r io.Reader, element interface{}) error { *e = lnwire.MilliSatoshi(a) + case **btcec.PrivateKey: + var b [btcec.PrivKeyBytesLen]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), b[:]) + *e = priv + case **btcec.PublicKey: var b [btcec.PubKeyBytesLenCompressed]byte if _, err := io.ReadFull(r, b[:]); err != nil { diff --git a/channeldb/control_tower.go b/channeldb/control_tower.go new file mode 100644 index 00000000..97b44180 --- /dev/null +++ b/channeldb/control_tower.go @@ -0,0 +1,463 @@ +package channeldb + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" +) + +var ( + // ErrAlreadyPaid signals we have already paid this payment hash. + ErrAlreadyPaid = errors.New("invoice is already paid") + + // ErrPaymentInFlight signals that payment for this payment hash is + // already "in flight" on the network. + ErrPaymentInFlight = errors.New("payment is in transition") + + // ErrPaymentNotInitiated is returned if payment wasn't initiated in + // switch. + ErrPaymentNotInitiated = errors.New("payment isn't initiated") + + // ErrPaymentAlreadySucceeded is returned in the event we attempt to + // change the status of a payment already succeeded. + ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded") + + // ErrPaymentAlreadyFailed is returned in the event we attempt to + // re-fail a failed payment. + ErrPaymentAlreadyFailed = errors.New("payment has already failed") + + // ErrUnknownPaymentStatus is returned when we do not recognize the + // existing state of a payment. + ErrUnknownPaymentStatus = errors.New("unknown payment status") +) + +// ControlTower tracks all outgoing payments made, whose primary purpose is to +// prevent duplicate payments to the same payment hash. In production, a +// persistent implementation is preferred so that tracking can survive across +// restarts. Payments are transitioned through various payment states, and the +// ControlTower interface provides access to driving the state transitions. +type ControlTower interface { + // InitPayment atomically moves the payment into the InFlight state. + // This method checks that no suceeded payment exist for this payment + // hash. + InitPayment(lntypes.Hash, *PaymentCreationInfo) error + + // RegisterAttempt atomically records the provided PaymentAttemptInfo. + RegisterAttempt(lntypes.Hash, *PaymentAttemptInfo) error + + // Success transitions a payment into the Succeeded state. After + // invoking this method, InitPayment should always return an error to + // prevent us from making duplicate payments to the same payment hash. + // The provided preimage is atomically saved to the DB for record + // keeping. + Success(lntypes.Hash, lntypes.Preimage) error + + // Fail transitions a payment into the Failed state, and records the + // reason the payment failed. After invoking this method, InitPayment + // should return nil on its next call for this payment hash, allowing + // the switch to make a subsequent payment. + Fail(lntypes.Hash, FailureReason) error + + // FetchInFlightPayments returns all payments with status InFlight. + FetchInFlightPayments() ([]*InFlightPayment, error) +} + +// paymentControl is persistent implementation of ControlTower to restrict +// double payment sending. +type paymentControl struct { + db *DB +} + +// NewPaymentControl creates a new instance of the paymentControl. +func NewPaymentControl(db *DB) ControlTower { + return &paymentControl{ + db: db, + } +} + +// InitPayment checks or records the given PaymentCreationInfo with the DB, +// making sure it does not already exist as an in-flight payment. Then this +// method returns successfully, the payment is guranteeed to be in the InFlight +// state. +func (p *paymentControl) InitPayment(paymentHash lntypes.Hash, + info *PaymentCreationInfo) error { + + var b bytes.Buffer + if err := serializePaymentCreationInfo(&b, info); err != nil { + return err + } + infoBytes := b.Bytes() + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := createPaymentBucket(tx, paymentHash) + if err != nil { + return err + } + + // Get the existing status of this payment, if any. + paymentStatus := fetchPaymentStatus(bucket) + + switch paymentStatus { + + // We allow retrying failed payments. + case StatusFailed: + + // This is a new payment that is being initialized for the + // first time. + case StatusUnknown: + + // We already have an InFlight payment on the network. We will + // disallow any new payments. + case StatusInFlight: + updateErr = ErrPaymentInFlight + return nil + + // We've already succeeded a payment to this payment hash, + // forbid the switch from sending another. + case StatusSucceeded: + updateErr = ErrAlreadyPaid + return nil + + default: + updateErr = ErrUnknownPaymentStatus + return nil + } + + // Obtain a new sequence number for this payment. This is used + // to sort the payments in order of creation, and also acts as + // a unique identifier for each payment. + sequenceNum, err := nextPaymentSequence(tx) + if err != nil { + return err + } + + err = bucket.Put(paymentSequenceKey, sequenceNum) + if err != nil { + return err + } + + // Add the payment info to the bucket, which contains the + // static information for this payment + err = bucket.Put(paymentCreationInfoKey, infoBytes) + if err != nil { + return err + } + + // We'll delete any lingering attempt info to start with, in + // case we are initializing a payment that was attempted + // earlier, but left in a state where we could retry. + err = bucket.Delete(paymentAttemptInfoKey) + if err != nil { + return err + } + + // Also delete any lingering failure info now that we are + // re-attempting. + return bucket.Delete(paymentFailInfoKey) + }) + if err != nil { + return nil + } + + return updateErr +} + +// RegisterAttempt atomically records the provided PaymentAttemptInfo to the +// DB. +func (p *paymentControl) RegisterAttempt(paymentHash lntypes.Hash, + attempt *PaymentAttemptInfo) error { + + // Serialize the information before opening the db transaction. + var a bytes.Buffer + if err := serializePaymentAttemptInfo(&a, attempt); err != nil { + return err + } + attemptBytes := a.Bytes() + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only register attempts for payments that are + // in-flight. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Add the payment attempt to the payments bucket. + return bucket.Put(paymentAttemptInfoKey, attemptBytes) + }) + if err != nil { + return err + } + + return updateErr +} + +// Success transitions a payment into the Succeeded state. After invoking this +// method, InitPayment should always return an error to prevent us from making +// duplicate payments to the same payment hash. The provided preimage is +// atomically saved to the DB for record keeping. +func (p *paymentControl) Success(paymentHash lntypes.Hash, + preimage lntypes.Preimage) error { + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only mark in-flight payments as succeeded. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Record the successful payment info atomically to the + // payments record. + return bucket.Put(paymentSettleInfoKey, preimage[:]) + }) + if err != nil { + return err + } + + return updateErr + +} + +// Fail transitions a payment into the Failed state, and records the reason the +// payment failed. After invoking this method, InitPayment should return nil on +// its next call for this payment hash, allowing the switch to make a +// subsequent payment. +func (p *paymentControl) Fail(paymentHash lntypes.Hash, + reason FailureReason) error { + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only mark in-flight payments as failed. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Put the failure reason in the bucket for record keeping. + v := []byte{byte(reason)} + return bucket.Put(paymentFailInfoKey, v) + }) + if err != nil { + return err + } + + return updateErr +} + +// createPaymentBucket creates or fetches the sub-bucket assigned to this +// payment hash. +func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + return payments.CreateBucketIfNotExists(paymentHash[:]) +} + +// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If +// the bucket does not exist, it returns ErrPaymentNotInitiated. +func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil, ErrPaymentNotInitiated + } + + bucket := payments.Bucket(paymentHash[:]) + if bucket == nil { + return nil, ErrPaymentNotInitiated + } + + return bucket, nil + +} + +// nextPaymentSequence returns the next sequence number to store for a new +// payment. +func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) { + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + seq, err := payments.NextSequence() + if err != nil { + return nil, err + } + + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, seq) + return b, nil +} + +// fetchPaymentStatus fetches the payment status of the payment. If the payment +// isn't found, it will default to "StatusUnknown". +func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { + if bucket.Get(paymentSettleInfoKey) != nil { + return StatusSucceeded + } + + if bucket.Get(paymentFailInfoKey) != nil { + return StatusFailed + } + + if bucket.Get(paymentCreationInfoKey) != nil { + return StatusInFlight + } + + return StatusUnknown +} + +// ensureInFlight checks whether the payment found in the given bucket has +// status InFlight, and returns an error otherwise. This should be used to +// ensure we only mark in-flight payments as succeeded or failed. +func ensureInFlight(bucket *bbolt.Bucket) error { + paymentStatus := fetchPaymentStatus(bucket) + + switch { + + // The payment was indeed InFlight, return. + case paymentStatus == StatusInFlight: + return nil + + // Our records show the payment as unknown, meaning it never + // should have left the switch. + case paymentStatus == StatusUnknown: + return ErrPaymentNotInitiated + + // The payment succeeded previously. + case paymentStatus == StatusSucceeded: + return ErrPaymentAlreadySucceeded + + // The payment was already failed. + case paymentStatus == StatusFailed: + return ErrPaymentAlreadyFailed + + default: + return ErrUnknownPaymentStatus + } +} + +// InFlightPayment is a wrapper around a payment that has status InFlight. +type InFlightPayment struct { + // Info is the PaymentCreationInfo of the in-flight payment. + Info *PaymentCreationInfo + + // Attempt contains information about the last payment attempt that was + // made to this payment hash. + // + // NOTE: Might be nil. + Attempt *PaymentAttemptInfo +} + +// FetchInFlightPayments returns all payments with status InFlight. +func (p *paymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { + var inFlights []*InFlightPayment + err := p.db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + return payments.ForEach(func(k, _ []byte) error { + bucket := payments.Bucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element") + } + + // If the status is not InFlight, we can return early. + paymentStatus := fetchPaymentStatus(bucket) + if paymentStatus != StatusInFlight { + return nil + } + + var ( + inFlight = &InFlightPayment{} + err error + ) + + // Get the CreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return fmt.Errorf("unable to find creation " + + "info for inflight payment") + } + + r := bytes.NewReader(b) + inFlight.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return err + } + + // Now get the attempt info, which may or may not be + // available. + attempt := bucket.Get(paymentAttemptInfoKey) + if attempt != nil { + r = bytes.NewReader(attempt) + inFlight.Attempt, err = deserializePaymentAttemptInfo(r) + if err != nil { + return err + } + } + + inFlights = append(inFlights, inFlight) + return nil + }) + }) + if err != nil { + return nil, err + } + + return inFlights, nil +} diff --git a/channeldb/control_tower_test.go b/channeldb/control_tower_test.go new file mode 100644 index 00000000..370300d7 --- /dev/null +++ b/channeldb/control_tower_test.go @@ -0,0 +1,546 @@ +package channeldb + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "reflect" + "testing" + "time" + + "github.com/btcsuite/fastsha256" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" +) + +func initDB() (*DB, error) { + tempPath, err := ioutil.TempDir("", "switchdb") + if err != nil { + return nil, err + } + + db, err := Open(tempPath) + if err != nil { + return nil, err + } + + return db, err +} + +func genPreimage() ([32]byte, error) { + var preimage [32]byte + if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { + return preimage, err + } + return preimage, nil +} + +func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo, + lntypes.Preimage, error) { + + preimage, err := genPreimage() + if err != nil { + return nil, nil, preimage, fmt.Errorf("unable to "+ + "generate preimage: %v", err) + } + + rhash := fastsha256.Sum256(preimage[:]) + return &PaymentCreationInfo{ + PaymentHash: rhash, + Value: 1, + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + }, + &PaymentAttemptInfo{ + PaymentID: 1, + SessionKey: priv, + Route: testRoute, + }, preimage, nil +} + +// TestPaymentControlSwitchFail checks that payment status returns to Failed +// status after failing, and that InitPayment allows another HTLC for the +// same payment hash. +func TestPaymentControlSwitchFail(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + &failReason, + ) + + // Sends the htlc again, which should succeed since the prior payment + // failed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Record a new attempt. + attempt.PaymentID = 2 + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + nil, + ) + + // Verifies that status was changed to StatusSucceeded. + if err := pControl.Success(info.PaymentHash, preimg); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + + // Attempt a final payment, which should now fail since the prior + // payment succeed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// TestPaymentControlSwitchDoubleSend checks the ability of payment control to +// prevent double sending of htlc message, when message is in StatusInFlight. +func TestPaymentControlSwitchDoubleSend(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate base status and move it to + // StatusInFlight and verifies that it was changed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Try to initiate double sending of htlc message with the same + // payment hash, should result in error indicating that payment has + // already been sent. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } + + // Record an attempt. + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + nil, + ) + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } + + // After settling, the error should be ErrAlreadyPaid. + if err := pControl.Success(info.PaymentHash, preimg); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// TestPaymentControlSuccessesWithoutInFlight checks that the payment +// control will disallow calls to Success when no payment is in flight. +func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, _, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Attempt to complete the payment should fail. + err = pControl.Success(info.PaymentHash, preimg) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) + assertPaymentInfo( + t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, + nil, + ) +} + +// TestPaymentControlFailsWithoutInFlight checks that a strict payment +// control will disallow calls to Fail when no payment is in flight. +func TestPaymentControlFailsWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, _, _, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Calling Fail should return an error. + err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) + assertPaymentInfo( + t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil, + ) +} + +// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// deletes payments from the database that are not in-flight. +func TestPaymentControlDeleteNonInFligt(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + payments := []struct { + failed bool + success bool + }{ + { + failed: true, + success: false, + }, + { + failed: false, + success: true, + }, + { + failed: false, + success: false, + }, + } + + for _, p := range payments { + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + if p.failed { + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, &failReason, + ) + } else if p.success { + // Verifies that status was changed to StatusSucceeded. + err := pControl.Success(info.PaymentHash, preimg) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, preimg, nil, + ) + } else { + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, nil, + ) + } + } + + // Delete payments. + if err := db.DeletePayments(); err != nil { + t.Fatal(err) + } + + // This should leave the in-flight payment. + dbPayments, err := db.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != 1 { + t.Fatalf("expected one payment, got %d", len(dbPayments)) + } + + status := dbPayments[0].Status + if status != StatusInFlight { + t.Fatalf("expected in-fligth status, got %v", status) + } +} + +func assertPaymentStatus(t *testing.T, db *DB, + hash [32]byte, expStatus PaymentStatus) { + + t.Helper() + + var paymentStatus = StatusUnknown + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil { + return nil + } + + // Get the existing status of this payment, if any. + paymentStatus = fetchPaymentStatus(bucket) + return nil + }) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != expStatus { + t.Fatalf("payment status mismatch: expected %v, got %v", + expStatus, paymentStatus) + } +} + +func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { + b := bucket.Get(paymentCreationInfoKey) + switch { + case b == nil && c == nil: + return nil + case b == nil: + return fmt.Errorf("expected creation info not found") + case c == nil: + return fmt.Errorf("unexpected creation info found") + } + + r := bytes.NewReader(b) + c2, err := deserializePaymentCreationInfo(r) + if err != nil { + fmt.Println("creation info err: ", err) + return err + } + if !reflect.DeepEqual(c, c2) { + return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", + spew.Sdump(c), spew.Sdump(c2)) + } + + return nil +} + +func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error { + b := bucket.Get(paymentAttemptInfoKey) + switch { + case b == nil && a == nil: + return nil + case b == nil: + return fmt.Errorf("expected attempt info not found") + case a == nil: + return fmt.Errorf("unexpected attempt info found") + } + + r := bytes.NewReader(b) + a2, err := deserializePaymentAttemptInfo(r) + if err != nil { + return err + } + if !reflect.DeepEqual(a, a2) { + return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", + spew.Sdump(a), spew.Sdump(a2)) + } + + return nil +} + +func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { + zero := lntypes.Preimage{} + b := bucket.Get(paymentSettleInfoKey) + switch { + case b == nil && preimg == zero: + return nil + case b == nil: + return fmt.Errorf("expected preimage not found") + case preimg == zero: + return fmt.Errorf("unexpected preimage found") + } + + var pre2 lntypes.Preimage + copy(pre2[:], b[:]) + if preimg != pre2 { + return fmt.Errorf("Preimages don't match: %x vs %x", + preimg, pre2) + } + + return nil +} + +func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error { + b := bucket.Get(paymentFailInfoKey) + switch { + case b == nil && failReason == nil: + return nil + case b == nil: + return fmt.Errorf("expected fail info not found") + case failReason == nil: + return fmt.Errorf("unexpected fail info found") + } + + failReason2 := FailureReason(b[0]) + if *failReason != failReason2 { + return fmt.Errorf("Failure infos don't match: %v vs %v", + *failReason, failReason2) + } + + return nil +} + +func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, + c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage, + f *FailureReason) { + + t.Helper() + + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil && c == nil { + return nil + } + if payments == nil { + return fmt.Errorf("sent payments not found") + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil && c == nil { + return nil + } + + if bucket == nil { + return fmt.Errorf("payment not found") + } + + if err := checkPaymentCreationInfo(bucket, c); err != nil { + return err + } + + if err := checkPaymentAttemptInfo(bucket, a); err != nil { + return err + } + + if err := checkSettleInfo(bucket, s); err != nil { + return err + } + + if err := checkFailInfo(bucket, f); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatalf("assert payment info failed: %v", err) + } + +} diff --git a/channeldb/db.go b/channeldb/db.go index aecb75e4..e9a9a185 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -96,6 +96,13 @@ var ( number: 8, migration: migrateGossipMessageStoreKeys, }, + { + // The DB version where the payments and payment + // statuses are moved to being stored in a combined + // bucket. + number: 9, + migration: migrateOutgoingPayments, + }, } // Big endian is the preferred byte order, due to cursor scans over diff --git a/channeldb/legacy_serialization.go b/channeldb/legacy_serialization.go index 2abb3f04..3ca9b5ec 100644 --- a/channeldb/legacy_serialization.go +++ b/channeldb/legacy_serialization.go @@ -1,6 +1,8 @@ package channeldb -import "io" +import ( + "io" +) // deserializeCloseChannelSummaryV6 reads the v6 database format for // ChannelCloseSummary. diff --git a/channeldb/meta_test.go b/channeldb/meta_test.go index 76d0cb25..ab3c410e 100644 --- a/channeldb/meta_test.go +++ b/channeldb/meta_test.go @@ -20,6 +20,16 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), t.Fatal(err) } + // Create a test node that will be our source node. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatal(err) + } + graph := cdb.ChannelGraph() + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatal(err) + } + // beforeMigration usually used for populating the database // with test data. beforeMigration(cdb) diff --git a/channeldb/migration_09_legacy_serialization.go b/channeldb/migration_09_legacy_serialization.go new file mode 100644 index 00000000..52e765ed --- /dev/null +++ b/channeldb/migration_09_legacy_serialization.go @@ -0,0 +1,255 @@ +package channeldb + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // paymentBucket is the name of the bucket within the database that + // stores all data related to payments. + // + // Within the payments bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint64. BoltDB's sequence + // feature is used for generating monotonically increasing id. + // + // NOTE: Deprecated. Kept around for migration purposes. + paymentBucket = []byte("payments") + + // paymentStatusBucket is the name of the bucket within the database + // that stores the status of a payment indexed by the payment's + // preimage. + // + // NOTE: Deprecated. Kept around for migration purposes. + paymentStatusBucket = []byte("payment-status") +) + +// outgoingPayment represents a successful payment between the daemon and a +// remote node. Details such as the total fee paid, and the time of the payment +// are stored. +// +// NOTE: Deprecated. Kept around for migration purposes. +type outgoingPayment struct { + Invoice + + // Fee is the total fee paid for the payment in milli-satoshis. + Fee lnwire.MilliSatoshi + + // TotalTimeLock is the total cumulative time-lock in the HTLC extended + // from the second-to-last hop to the destination. + TimeLockLength uint32 + + // Path encodes the path the payment took through the network. The path + // excludes the outgoing node and consists of the hex-encoded + // compressed public key of each of the nodes involved in the payment. + Path [][33]byte + + // PaymentPreimage is the preImage of a successful payment. This is used + // to calculate the PaymentHash as well as serve as a proof of payment. + PaymentPreimage [32]byte +} + +// addPayment saves a successful payment to the database. It is assumed that +// all payment are sent using unique payment hashes. +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) addPayment(payment *outgoingPayment) error { + // Validate the field of the inner voice within the outgoing payment, + // these must also adhere to the same constraints as regular invoices. + if err := validateInvoice(&payment.Invoice); err != nil { + return err + } + + // We first serialize the payment before starting the database + // transaction so we can avoid creating a DB payment in the case of a + // serialization error. + var b bytes.Buffer + if err := serializeOutgoingPayment(&b, payment); err != nil { + return err + } + paymentBytes := b.Bytes() + + return db.Batch(func(tx *bbolt.Tx) error { + payments, err := tx.CreateBucketIfNotExists(paymentBucket) + if err != nil { + return err + } + + // Obtain the new unique sequence number for this payment. + paymentID, err := payments.NextSequence() + if err != nil { + return err + } + + // We use BigEndian for keys as it orders keys in + // ascending order. This allows bucket scans to order payments + // in the order in which they were created. + paymentIDBytes := make([]byte, 8) + binary.BigEndian.PutUint64(paymentIDBytes, paymentID) + + return payments.Put(paymentIDBytes, paymentBytes) + }) +} + +// fetchAllPayments returns all outgoing payments in DB. +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) { + var payments []*outgoingPayment + + err := db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return ErrNoPaymentsCreated + } + + return bucket.ForEach(func(k, v []byte) error { + // If the value is nil, then we ignore it as it may be + // a sub-bucket. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + payments = append(payments, payment) + return nil + }) + }) + if err != nil { + return nil, err + } + + return payments, nil +} + +// fetchPaymentStatus returns the payment status for outgoing payment. +// If status of the payment isn't found, it will default to "StatusUnknown". +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { + var paymentStatus = StatusUnknown + err := db.View(func(tx *bbolt.Tx) error { + var err error + paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) + return err + }) + if err != nil { + return StatusUnknown, err + } + + return paymentStatus, nil +} + +// fetchPaymentStatusTx is a helper method that returns the payment status for +// outgoing payment. If status of the payment isn't found, it will default to +// "StatusUnknown". It accepts the boltdb transactions such that this method +// can be composed into other atomic operations. +// +// NOTE: Deprecated. Kept around for migration purposes. +func fetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, error) { + // The default status for all payments that aren't recorded in database. + var paymentStatus = StatusUnknown + + bucket := tx.Bucket(paymentStatusBucket) + if bucket == nil { + return paymentStatus, nil + } + + paymentStatusBytes := bucket.Get(paymentHash[:]) + if paymentStatusBytes == nil { + return paymentStatus, nil + } + + paymentStatus.FromBytes(paymentStatusBytes) + + return paymentStatus, nil +} + +func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error { + var scratch [8]byte + + if err := serializeInvoice(w, &p.Invoice); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(p.Fee)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + // First write out the length of the bytes to prefix the value. + pathLen := uint32(len(p.Path)) + byteOrder.PutUint32(scratch[:4], pathLen) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + // Then with the path written, we write out the series of public keys + // involved in the path. + for _, hop := range p.Path { + if _, err := w.Write(hop[:]); err != nil { + return err + } + } + + byteOrder.PutUint32(scratch[:4], p.TimeLockLength) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + if _, err := w.Write(p.PaymentPreimage[:]); err != nil { + return err + } + + return nil +} + +func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) { + var scratch [8]byte + + p := &outgoingPayment{} + + inv, err := deserializeInvoice(r) + if err != nil { + return nil, err + } + p.Invoice = inv + + if _, err := r.Read(scratch[:]); err != nil { + return nil, err + } + p.Fee = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if _, err = r.Read(scratch[:4]); err != nil { + return nil, err + } + pathLen := byteOrder.Uint32(scratch[:4]) + + path := make([][33]byte, pathLen) + for i := uint32(0); i < pathLen; i++ { + if _, err := r.Read(path[i][:]); err != nil { + return nil, err + } + } + p.Path = path + + if _, err = r.Read(scratch[:4]); err != nil { + return nil, err + } + p.TimeLockLength = byteOrder.Uint32(scratch[:4]) + + if _, err := r.Read(p.PaymentPreimage[:]); err != nil { + return nil, err + } + + return p, nil +} diff --git a/channeldb/migrations.go b/channeldb/migrations.go index 72ba7882..0a7098c0 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -6,8 +6,10 @@ import ( "encoding/binary" "fmt" + "github.com/btcsuite/btcd/btcec" "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" ) // migrateNodeAndEdgeUpdateIndex is a migration function that will update the @@ -449,7 +451,7 @@ func paymentStatusesMigration(tx *bbolt.Tx) error { // Update status for current payment to completed. If it fails, // the migration is aborted and the payment bucket is returned // to its previous state. - return paymentStatuses.Put(paymentHash[:], StatusCompleted.Bytes()) + return paymentStatuses.Put(paymentHash[:], StatusSucceeded.Bytes()) }) if err != nil { return err @@ -683,3 +685,190 @@ func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { return nil } + +// migrateOutgoingPayments moves the OutgoingPayments into a new bucket format +// where they all reside in a top-level bucket indexed by the payment hash. In +// this sub-bucket we store information relevant to this payment, such as the +// payment status. +// +// Since the router cannot handle resumed payments that have the status +// InFlight (we have no PaymentAttemptInfo available for pre-migration +// payments) we delete those statuses, so only Completed payments remain in the +// new bucket structure. +func migrateOutgoingPayments(tx *bbolt.Tx) error { + oldPayments, err := tx.CreateBucketIfNotExists(paymentBucket) + if err != nil { + return err + } + + newPayments, err := tx.CreateBucket(paymentsRootBucket) + if err != nil { + return err + } + + // Get the source pubkey. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return ErrSourceNodeNotSet + } + var sourcePubKey [33]byte + copy(sourcePubKey[:], selfPub[:]) + + log.Infof("Migrating outgoing payments to new bucket structure") + + err = oldPayments.ForEach(func(k, v []byte) error { + // Ignores if it is sub-bucket. + if v == nil { + return nil + } + + // Read the old payment format. + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // Calculate payment hash from the payment preimage. + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // Now create and add a PaymentCreationInfo to the bucket. + c := &PaymentCreationInfo{ + PaymentHash: paymentHash, + Value: payment.Terms.Value, + CreationDate: payment.CreationDate, + PaymentRequest: payment.PaymentRequest, + } + + var infoBuf bytes.Buffer + if err := serializePaymentCreationInfo(&infoBuf, c); err != nil { + return err + } + + // Do the same for the PaymentAttemptInfo. + totalAmt := payment.Terms.Value + payment.Fee + rt := route.Route{ + TotalTimeLock: payment.TimeLockLength, + TotalAmount: totalAmt, + SourcePubKey: sourcePubKey, + Hops: []*route.Hop{}, + } + for _, hop := range payment.Path { + rt.Hops = append(rt.Hops, &route.Hop{ + PubKeyBytes: hop, + AmtToForward: totalAmt, + }) + } + + // Since the old format didn't store the fee for individual + // hops, we let the last hop eat the whole fee for the total to + // add up. + if len(rt.Hops) > 0 { + rt.Hops[len(rt.Hops)-1].AmtToForward = payment.Terms.Value + } + + // Since we don't have the session key for old payments, we + // create a random one to be able to serialize the attempt + // info. + priv, _ := btcec.NewPrivateKey(btcec.S256()) + s := &PaymentAttemptInfo{ + PaymentID: 0, // unknown. + SessionKey: priv, // unknown. + Route: rt, + } + + var attemptBuf bytes.Buffer + if err := serializePaymentAttemptInfo(&attemptBuf, s); err != nil { + return err + } + + // Reuse the existing payment sequence number. + var seqNum [8]byte + copy(seqNum[:], k) + + // Create a bucket indexed by the payment hash. + bucket, err := newPayments.CreateBucket(paymentHash[:]) + + // If the bucket already exists, it means that we are migrating + // from a database containing duplicate payments to a payment + // hash. To keep this information, we store such duplicate + // payments in a sub-bucket. + if err == bbolt.ErrBucketExists { + pHashBucket := newPayments.Bucket(paymentHash[:]) + + // Create a bucket for duplicate payments within this + // payment hash's bucket. + dup, err := pHashBucket.CreateBucketIfNotExists( + paymentDuplicateBucket, + ) + if err != nil { + return err + } + + // Each duplicate will get its own sub-bucket within + // this bucket, so use their sequence number to index + // them by. + bucket, err = dup.CreateBucket(seqNum[:]) + if err != nil { + return err + } + + } else if err != nil { + return err + } + + // Store the payment's information to the bucket. + err = bucket.Put(paymentSequenceKey, seqNum[:]) + if err != nil { + return err + } + + err = bucket.Put(paymentCreationInfoKey, infoBuf.Bytes()) + if err != nil { + return err + } + + err = bucket.Put(paymentAttemptInfoKey, attemptBuf.Bytes()) + if err != nil { + return err + } + + err = bucket.Put(paymentSettleInfoKey, payment.PaymentPreimage[:]) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + + // To continue producing unique sequence numbers, we set the sequence + // of the new bucket to that of the old one. + seq := oldPayments.Sequence() + if err := newPayments.SetSequence(seq); err != nil { + return err + } + + // Now we delete the old buckets. Deleting the payment status buckets + // deletes all payment statuses other than Complete. + err = tx.DeleteBucket(paymentStatusBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + // Finally delete the old payment bucket. + err = tx.DeleteBucket(paymentBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + log.Infof("Migration of outgoing payment bucket structure completed!") + return nil +} diff --git a/channeldb/migrations_test.go b/channeldb/migrations_test.go index 9223108d..aee98629 100644 --- a/channeldb/migrations_test.go +++ b/channeldb/migrations_test.go @@ -26,11 +26,11 @@ func TestPaymentStatusesMigration(t *testing.T) { // Add fake payment to test database, verifying that it was created, // that we have only one payment, and its status is not "Completed". beforeMigrationFunc := func(d *DB) { - if err := d.AddPayment(fakePayment); err != nil { + if err := d.addPayment(fakePayment); err != nil { t.Fatalf("unable to add payment: %v", err) } - payments, err := d.FetchAllPayments() + payments, err := d.fetchAllPayments() if err != nil { t.Fatalf("unable to fetch payments: %v", err) } @@ -40,15 +40,15 @@ func TestPaymentStatusesMigration(t *testing.T) { len(payments)) } - paymentStatus, err := d.FetchPaymentStatus(paymentHash) + paymentStatus, err := d.fetchPaymentStatus(paymentHash) if err != nil { t.Fatalf("unable to fetch payment status: %v", err) } // We should receive default status if we have any in database. - if paymentStatus != StatusGrounded { + if paymentStatus != StatusUnknown { t.Fatalf("wrong payment status: expected %v, got %v", - StatusGrounded.String(), paymentStatus.String()) + StatusUnknown.String(), paymentStatus.String()) } // Lastly, we'll add a locally-sourced circuit and @@ -141,14 +141,14 @@ func TestPaymentStatusesMigration(t *testing.T) { } // Check that our completed payments were migrated. - paymentStatus, err := d.FetchPaymentStatus(paymentHash) + paymentStatus, err := d.fetchPaymentStatus(paymentHash) if err != nil { t.Fatalf("unable to fetch payment status: %v", err) } - if paymentStatus != StatusCompleted { + if paymentStatus != StatusSucceeded { t.Fatalf("wrong payment status: expected %v, got %v", - StatusCompleted.String(), paymentStatus.String()) + StatusSucceeded.String(), paymentStatus.String()) } inFlightHash := [32]byte{ @@ -160,7 +160,7 @@ func TestPaymentStatusesMigration(t *testing.T) { // Check that the locally sourced payment was transitioned to // InFlight. - paymentStatus, err = d.FetchPaymentStatus(inFlightHash) + paymentStatus, err = d.fetchPaymentStatus(inFlightHash) if err != nil { t.Fatalf("unable to fetch payment status: %v", err) } @@ -179,14 +179,14 @@ func TestPaymentStatusesMigration(t *testing.T) { // Check that non-locally sourced payments remain in the default // Grounded state. - paymentStatus, err = d.FetchPaymentStatus(groundedHash) + paymentStatus, err = d.fetchPaymentStatus(groundedHash) if err != nil { t.Fatalf("unable to fetch payment status: %v", err) } - if paymentStatus != StatusGrounded { + if paymentStatus != StatusUnknown { t.Fatalf("wrong payment status: expected %v, got %v", - StatusGrounded.String(), paymentStatus.String()) + StatusUnknown.String(), paymentStatus.String()) } } @@ -564,3 +564,162 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { migrateGossipMessageStoreKeys, false, ) } + +// TestOutgoingPaymentsMigration checks that OutgoingPayments are migrated to a +// new bucket structure after the migration. +func TestOutgoingPaymentsMigration(t *testing.T) { + t.Parallel() + + const numPayments = 4 + var oldPayments []*outgoingPayment + + // Add fake payments to test database, verifying that it was created. + beforeMigrationFunc := func(d *DB) { + for i := 0; i < numPayments; i++ { + var p *outgoingPayment + var err error + + // We fill the database with random payments. For the + // very last one we'll use a duplicate of the first, to + // ensure we are able to handle migration from a + // database that has copies. + if i < numPayments-1 { + p, err = makeRandomFakePayment() + if err != nil { + t.Fatalf("unable to create payment: %v", + err) + } + } else { + p = oldPayments[0] + } + + if err := d.addPayment(p); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + oldPayments = append(oldPayments, p) + } + + payments, err := d.fetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != numPayments { + t.Fatalf("wrong qty of paymets: expected %d got %v", + numPayments, len(payments)) + } + } + + // Verify that all payments were migrated. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + sentPayments, err := d.FetchPayments() + if err != nil { + t.Fatalf("unable to fetch sent payments: %v", err) + } + + if len(sentPayments) != numPayments { + t.Fatalf("expected %d payments, got %d", numPayments, + len(sentPayments)) + } + + graph := d.ChannelGraph() + sourceNode, err := graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + + for i, p := range sentPayments { + // The payment status should be Completed. + if p.Status != StatusSucceeded { + t.Fatalf("expected Completed, got %v", p.Status) + } + + // Check that the sequence number is preserved. They + // start counting at 1. + if p.sequenceNum != uint64(i+1) { + t.Fatalf("expected seqnum %d, got %d", i, + p.sequenceNum) + } + + // Order of payments should be be preserved. + old := oldPayments[i] + + // Check the individial fields. + if p.Info.Value != old.Terms.Value { + t.Fatalf("value mismatch") + } + + if p.Info.CreationDate != old.CreationDate { + t.Fatalf("date mismatch") + } + + if !bytes.Equal(p.Info.PaymentRequest, old.PaymentRequest) { + t.Fatalf("payreq mismatch") + } + + if *p.PaymentPreimage != old.PaymentPreimage { + t.Fatalf("preimage mismatch") + } + + if p.Attempt.Route.TotalFees() != old.Fee { + t.Fatalf("Fee mismatch") + } + + if p.Attempt.Route.TotalAmount != old.Fee+old.Terms.Value { + t.Fatalf("Total amount mismatch") + } + + if p.Attempt.Route.TotalTimeLock != old.TimeLockLength { + t.Fatalf("timelock mismatch") + } + + if p.Attempt.Route.SourcePubKey != sourceNode.PubKeyBytes { + t.Fatalf("source mismatch: %x vs %x", + p.Attempt.Route.SourcePubKey[:], + sourceNode.PubKeyBytes[:]) + } + + for i, hop := range old.Path { + if hop != p.Attempt.Route.Hops[i].PubKeyBytes { + t.Fatalf("path mismatch") + } + } + } + + // Finally, check that the payment sequence number is updated + // to reflect the migrated payments. + err = d.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return fmt.Errorf("payments bucket not found") + } + + seq := payments.Sequence() + if seq != numPayments { + return fmt.Errorf("expected sequence to be "+ + "%d, got %d", numPayments, seq) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateOutgoingPayments, + false) +} diff --git a/channeldb/payments.go b/channeldb/payments.go index 08c9e02f..9eabe37d 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -4,43 +4,118 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" + "sort" + "time" + "github.com/btcsuite/btcd/btcec" "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) var ( - // paymentBucket is the name of the bucket within the database that - // stores all data related to payments. + // paymentsRootBucket is the name of the top-level bucket within the + // database that stores all data related to payments. Within this + // bucket, each payment hash its own sub-bucket keyed by its payment + // hash. // - // Within the payments bucket, each invoice is keyed by its invoice ID - // which is a monotonically increasing uint64. BoltDB's sequence - // feature is used for generating monotonically increasing id. - paymentBucket = []byte("payments") + // Bucket hierarchy: + // + // root-bucket + // | + // |-- + // | |--sequence-key: + // | |--creation-info-key: + // | |--attempt-info-key: + // | |--settle-info-key: + // | |--fail-info-key: + // | | + // | |--duplicate-bucket (only for old, completed payments) + // | | + // | |-- + // | | |--sequence-key: + // | | |--creation-info-key: + // | | |--attempt-info-key: + // | | |--settle-info-key: + // | | |--fail-info-key: + // | | + // | |-- + // | | | + // | ... ... + // | + // |-- + // | | + // | ... + // ... + // + paymentsRootBucket = []byte("payments-root-bucket") - // paymentStatusBucket is the name of the bucket within the database that - // stores the status of a payment indexed by the payment's preimage. - paymentStatusBucket = []byte("payment-status") + // paymentDublicateBucket is the name of a optional sub-bucket within + // the payment hash bucket, that is used to hold duplicate payments to + // a payment hash. This is needed to support information from earlier + // versions of lnd, where it was possible to pay to a payment hash more + // than once. + paymentDuplicateBucket = []byte("payment-duplicate-bucket") + + // paymentSequenceKey is a key used in the payment's sub-bucket to + // store the sequence number of the payment. + paymentSequenceKey = []byte("payment-sequence-key") + + // paymentCreationInfoKey is a key used in the payment's sub-bucket to + // store the creation info of the payment. + paymentCreationInfoKey = []byte("payment-creation-info") + + // paymentAttemptInfoKey is a key used in the payment's sub-bucket to + // store the info about the latest attempt that was done for the + // payment in question. + paymentAttemptInfoKey = []byte("payment-attempt-info") + + // paymentSettleInfoKey is a key used in the payment's sub-bucket to + // store the settle info of the payment. + paymentSettleInfoKey = []byte("payment-settle-info") + + // paymentFailInfoKey is a key used in the payment's sub-bucket to + // store information about the reason a payment failed. + paymentFailInfoKey = []byte("payment-fail-info") +) + +// FailureReason encodes the reason a payment ultimately failed. +type FailureReason byte + +const ( + // FailureReasonTimeout indicates that the payment did timeout before a + // successful payment attempt was made. + FailureReasonTimeout FailureReason = 0 + + // FailureReasonNoRoute indicates no successful route to the + // destination was found during path finding. + FailureReasonNoRoute FailureReason = 1 + + // TODO(halseth): cancel state. ) // PaymentStatus represent current status of payment type PaymentStatus byte const ( - // StatusGrounded is the status where a payment has never been - // initiated, or has been initiated and received an intermittent - // failure. - StatusGrounded PaymentStatus = 0 + // StatusUnknown is the status where a payment has never been initiated + // and hence is unknown. + StatusUnknown PaymentStatus = 0 // StatusInFlight is the status where a payment has been initiated, but // a response has not been received. StatusInFlight PaymentStatus = 1 - // StatusCompleted is the status where a payment has been initiated and + // StatusSucceeded is the status where a payment has been initiated and // the payment was completed successfully. - StatusCompleted PaymentStatus = 2 + StatusSucceeded PaymentStatus = 2 + + // StatusFailed is the status where a payment has been initiated and a + // failure result has come back. + StatusFailed PaymentStatus = 3 ) // Bytes returns status as slice of bytes. @@ -55,7 +130,7 @@ func (ps *PaymentStatus) FromBytes(status []byte) error { } switch PaymentStatus(status[0]) { - case StatusGrounded, StatusInFlight, StatusCompleted: + case StatusUnknown, StatusInFlight, StatusSucceeded, StatusFailed: *ps = PaymentStatus(status[0]) default: return errors.New("unknown payment status") @@ -67,267 +142,339 @@ func (ps *PaymentStatus) FromBytes(status []byte) error { // String returns readable representation of payment status. func (ps PaymentStatus) String() string { switch ps { - case StatusGrounded: - return "Grounded" + case StatusUnknown: + return "Unknown" case StatusInFlight: return "In Flight" - case StatusCompleted: - return "Completed" + case StatusSucceeded: + return "Succeeded" + case StatusFailed: + return "Failed" default: return "Unknown" } } -// OutgoingPayment represents a successful payment between the daemon and a -// remote node. Details such as the total fee paid, and the time of the payment -// are stored. -type OutgoingPayment struct { - Invoice +// PaymentCreationInfo is the information necessary to have ready when +// initiating a payment, moving it into state InFlight. +type PaymentCreationInfo struct { + // PaymentHash is the hash this payment is paying to. + PaymentHash lntypes.Hash - // Fee is the total fee paid for the payment in milli-satoshis. - Fee lnwire.MilliSatoshi + // Value is the amount we are paying. + Value lnwire.MilliSatoshi - // TotalTimeLock is the total cumulative time-lock in the HTLC extended - // from the second-to-last hop to the destination. - TimeLockLength uint32 + // CreatingDate is the time when this payment was initiated. + CreationDate time.Time - // Path encodes the path the payment took through the network. The path - // excludes the outgoing node and consists of the hex-encoded - // compressed public key of each of the nodes involved in the payment. - Path [][33]byte - - // PaymentPreimage is the preImage of a successful payment. This is used - // to calculate the PaymentHash as well as serve as a proof of payment. - PaymentPreimage [32]byte + // PaymentRequest is the full payment request, if any. + PaymentRequest []byte } -// AddPayment saves a successful payment to the database. It is assumed that -// all payment are sent using unique payment hashes. -func (db *DB) AddPayment(payment *OutgoingPayment) error { - // Validate the field of the inner voice within the outgoing payment, - // these must also adhere to the same constraints as regular invoices. - if err := validateInvoice(&payment.Invoice); err != nil { - return err - } +// PaymentAttemptInfo contains information about a specific payment attempt for +// a given payment. This information is used by the router to handle any errors +// coming back after an attempt is made, and to query the switch about the +// status of a payment. For settled payment this will be the information for +// the succeeding payment attempt. +type PaymentAttemptInfo struct { + // PaymentID is the unique ID used for this attempt. + PaymentID uint64 - // We first serialize the payment before starting the database - // transaction so we can avoid creating a DB payment in the case of a - // serialization error. - var b bytes.Buffer - if err := serializeOutgoingPayment(&b, payment); err != nil { - return err - } - paymentBytes := b.Bytes() + // SessionKey is the ephemeral key used for this payment attempt. + SessionKey *btcec.PrivateKey - return db.Batch(func(tx *bbolt.Tx) error { - payments, err := tx.CreateBucketIfNotExists(paymentBucket) - if err != nil { - return err - } - - // Obtain the new unique sequence number for this payment. - paymentID, err := payments.NextSequence() - if err != nil { - return err - } - - // We use BigEndian for keys as it orders keys in - // ascending order. This allows bucket scans to order payments - // in the order in which they were created. - paymentIDBytes := make([]byte, 8) - binary.BigEndian.PutUint64(paymentIDBytes, paymentID) - - return payments.Put(paymentIDBytes, paymentBytes) - }) + // Route is the route attempted to send the HTLC. + Route route.Route } -// FetchAllPayments returns all outgoing payments in DB. -func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) { - var payments []*OutgoingPayment +// Payment is a wrapper around a payment's PaymentCreationInfo, +// PaymentAttemptInfo, and preimage. All payments will have the +// PaymentCreationInfo set, the PaymentAttemptInfo will be set only if at least +// one payment attempt has been made, while only completed payments will have a +// non-zero payment preimage. +type Payment struct { + // sequenceNum is a unique identifier used to sort the payments in + // order of creation. + sequenceNum uint64 + + // Status is the current PaymentStatus of this payment. + Status PaymentStatus + + // Info holds all static information about this payment, and is + // populated when the payment is initiated. + Info *PaymentCreationInfo + + // Attempt is the information about the last payment attempt made. + // + // NOTE: Can be nil if no attempt is yet made. + Attempt *PaymentAttemptInfo + + // PaymentPreimage is the preimage of a successful payment. This serves + // as a proof of payment. It will only be non-nil for settled payments. + // + // NOTE: Can be nil if payment is not settled. + PaymentPreimage *lntypes.Preimage + + // Failure is a failure reason code indicating the reason the payment + // failed. It is only non-nil for failed payments. + // + // NOTE: Can be nil if payment is not failed. + Failure *FailureReason +} + +// FetchPayments returns all sent payments found in the DB. +func (db *DB) FetchPayments() ([]*Payment, error) { + var payments []*Payment err := db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(paymentBucket) - if bucket == nil { - return ErrNoPaymentsCreated + paymentsBucket := tx.Bucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil } - return bucket.ForEach(func(k, v []byte) error { - // If the value is nil, then we ignore it as it may be - // a sub-bucket. - if v == nil { - return nil + return paymentsBucket.ForEach(func(k, v []byte) error { + bucket := paymentsBucket.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") } - r := bytes.NewReader(v) - payment, err := deserializeOutgoingPayment(r) + p, err := fetchPayment(bucket) if err != nil { return err } - payments = append(payments, payment) - return nil + payments = append(payments, p) + + // For older versions of lnd, duplicate payments to a + // payment has was possible. These will be found in a + // sub-bucket indexed by their sequence number if + // available. + dup := bucket.Bucket(paymentDuplicateBucket) + if dup == nil { + return nil + } + + return dup.ForEach(func(k, v []byte) error { + subBucket := dup.Bucket(k) + if subBucket == nil { + // We one bucket for each duplicate to + // be found. + return fmt.Errorf("non bucket element" + + "in duplicate bucket") + } + + p, err := fetchPayment(subBucket) + if err != nil { + return err + } + + payments = append(payments, p) + return nil + }) }) }) if err != nil { return nil, err } + // Before returning, sort the payments by their sequence number. + sort.Slice(payments, func(i, j int) bool { + return payments[i].sequenceNum < payments[j].sequenceNum + }) + return payments, nil } -// DeleteAllPayments deletes all payments from DB. -func (db *DB) DeleteAllPayments() error { +func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { + var ( + err error + p = &Payment{} + ) + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, fmt.Errorf("sequence number not found") + } + + p.sequenceNum = binary.BigEndian.Uint64(seqBytes) + + // Get the payment status. + p.Status = fetchPaymentStatus(bucket) + + // Get the PaymentCreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return nil, fmt.Errorf("creation info not found") + } + + r := bytes.NewReader(b) + p.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return nil, err + + } + + // Get the PaymentAttemptInfo. This can be unset. + b = bucket.Get(paymentAttemptInfoKey) + if b != nil { + r = bytes.NewReader(b) + p.Attempt, err = deserializePaymentAttemptInfo(r) + if err != nil { + return nil, err + } + } + + // Get the payment preimage. This is only found for + // completed payments. + b = bucket.Get(paymentSettleInfoKey) + if b != nil { + var preimg lntypes.Preimage + copy(preimg[:], b[:]) + p.PaymentPreimage = &preimg + } + + // Get failure reason if available. + b = bucket.Get(paymentFailInfoKey) + if b != nil { + reason := FailureReason(b[0]) + p.Failure = &reason + } + + return p, nil +} + +// DeletePayments deletes all completed and failed payments from the DB. +func (db *DB) DeletePayments() error { return db.Update(func(tx *bbolt.Tx) error { - err := tx.DeleteBucket(paymentBucket) - if err != nil && err != bbolt.ErrBucketNotFound { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + var deleteBuckets [][]byte + err := payments.ForEach(func(k, _ []byte) error { + bucket := payments.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + // If the status is InFlight, we cannot safely delete + // the payment information, so we return early. + paymentStatus := fetchPaymentStatus(bucket) + if paymentStatus == StatusInFlight { + return nil + } + + deleteBuckets = append(deleteBuckets, k) + return nil + }) + if err != nil { return err } - _, err = tx.CreateBucket(paymentBucket) - return err + for _, k := range deleteBuckets { + if err := payments.DeleteBucket(k); err != nil { + return err + } + } + + return nil }) } -// UpdatePaymentStatus sets the payment status for outgoing/finished payments in -// local database. -func (db *DB) UpdatePaymentStatus(paymentHash [32]byte, status PaymentStatus) error { - return db.Batch(func(tx *bbolt.Tx) error { - return UpdatePaymentStatusTx(tx, paymentHash, status) - }) -} - -// UpdatePaymentStatusTx is a helper method that sets the payment status for -// outgoing/finished payments in the local database. This method accepts a -// boltdb transaction such that the operation can be composed into other -// database transactions. -func UpdatePaymentStatusTx(tx *bbolt.Tx, - paymentHash [32]byte, status PaymentStatus) error { - - paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) - if err != nil { - return err - } - - return paymentStatuses.Put(paymentHash[:], status.Bytes()) -} - -// FetchPaymentStatus returns the payment status for outgoing payment. -// If status of the payment isn't found, it will default to "StatusGrounded". -func (db *DB) FetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { - var paymentStatus = StatusGrounded - err := db.View(func(tx *bbolt.Tx) error { - var err error - paymentStatus, err = FetchPaymentStatusTx(tx, paymentHash) - return err - }) - if err != nil { - return StatusGrounded, err - } - - return paymentStatus, nil -} - -// FetchPaymentStatusTx is a helper method that returns the payment status for -// outgoing payment. If status of the payment isn't found, it will default to -// "StatusGrounded". It accepts the boltdb transactions such that this method -// can be composed into other atomic operations. -func FetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, error) { - // The default status for all payments that aren't recorded in database. - var paymentStatus = StatusGrounded - - bucket := tx.Bucket(paymentStatusBucket) - if bucket == nil { - return paymentStatus, nil - } - - paymentStatusBytes := bucket.Get(paymentHash[:]) - if paymentStatusBytes == nil { - return paymentStatus, nil - } - - paymentStatus.FromBytes(paymentStatusBytes) - - return paymentStatus, nil -} - -func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { +func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte - if err := serializeInvoice(w, &p.Invoice); err != nil { + if _, err := w.Write(c.PaymentHash[:]); err != nil { return err } - byteOrder.PutUint64(scratch[:], uint64(p.Fee)) + byteOrder.PutUint64(scratch[:], uint64(c.Value)) if _, err := w.Write(scratch[:]); err != nil { return err } - // First write out the length of the bytes to prefix the value. - pathLen := uint32(len(p.Path)) - byteOrder.PutUint32(scratch[:4], pathLen) + byteOrder.PutUint64(scratch[:], uint64(c.CreationDate.Unix())) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + byteOrder.PutUint32(scratch[:4], uint32(len(c.PaymentRequest))) if _, err := w.Write(scratch[:4]); err != nil { return err } - // Then with the path written, we write out the series of public keys - // involved in the path. - for _, hop := range p.Path { - if _, err := w.Write(hop[:]); err != nil { - return err - } - } - - byteOrder.PutUint32(scratch[:4], p.TimeLockLength) - if _, err := w.Write(scratch[:4]); err != nil { - return err - } - - if _, err := w.Write(p.PaymentPreimage[:]); err != nil { + if _, err := w.Write(c.PaymentRequest[:]); err != nil { return err } return nil } -func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { +func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) { var scratch [8]byte - p := &OutgoingPayment{} + c := &PaymentCreationInfo{} - inv, err := deserializeInvoice(r) - if err != nil { + if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil { return nil, err } - p.Invoice = inv - if _, err := r.Read(scratch[:]); err != nil { + if _, err := io.ReadFull(r, scratch[:]); err != nil { return nil, err } - p.Fee = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) - if _, err = r.Read(scratch[:4]); err != nil { + if _, err := io.ReadFull(r, scratch[:]); err != nil { return nil, err } - pathLen := byteOrder.Uint32(scratch[:4]) + c.CreationDate = time.Unix(int64(byteOrder.Uint64(scratch[:])), 0) - path := make([][33]byte, pathLen) - for i := uint32(0); i < pathLen; i++ { - if _, err := r.Read(path[i][:]); err != nil { + if _, err := io.ReadFull(r, scratch[:4]); err != nil { + return nil, err + } + + reqLen := uint32(byteOrder.Uint32(scratch[:4])) + payReq := make([]byte, reqLen) + if reqLen > 0 { + if _, err := io.ReadFull(r, payReq[:]); err != nil { return nil, err } } - p.Path = path + c.PaymentRequest = payReq - if _, err = r.Read(scratch[:4]); err != nil { - return nil, err - } - p.TimeLockLength = byteOrder.Uint32(scratch[:4]) + return c, nil +} - if _, err := r.Read(p.PaymentPreimage[:]); err != nil { - return nil, err +func serializePaymentAttemptInfo(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err } - return p, nil + if err := serializeRoute(w, a.Route); err != nil { + return err + } + + return nil +} + +func deserializePaymentAttemptInfo(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = deserializeRoute(r) + if err != nil { + return nil, err + } + return a, nil } func serializeHop(w io.Writer, h *route.Hop) error { diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 80a61be2..2be1f38b 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -10,6 +10,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -36,7 +37,7 @@ var ( } ) -func makeFakePayment() *OutgoingPayment { +func makeFakePayment() *outgoingPayment { fakeInvoice := &Invoice{ // Use single second precision to avoid false positive test // failures due to the monotonic time component. @@ -54,7 +55,7 @@ func makeFakePayment() *OutgoingPayment { copy(fakePath[i][:], bytes.Repeat([]byte{byte(i)}, 33)) } - fakePayment := &OutgoingPayment{ + fakePayment := &outgoingPayment{ Invoice: *fakeInvoice, Fee: 101, Path: fakePath, @@ -64,6 +65,27 @@ func makeFakePayment() *OutgoingPayment { return fakePayment } +func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) { + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + c := &PaymentCreationInfo{ + PaymentHash: preimg.Hash(), + Value: 1000, + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte(""), + } + + a := &PaymentAttemptInfo{ + PaymentID: 44, + SessionKey: priv, + Route: testRoute, + } + return c, a +} + func makeFakePaymentHash() [32]byte { var paymentHash [32]byte rBytes, _ := randomBytes(0, 32) @@ -84,7 +106,7 @@ func randomBytes(minLen, maxLen int) ([]byte, error) { return randBuf, nil } -func makeRandomFakePayment() (*OutgoingPayment, error) { +func makeRandomFakePayment() (*outgoingPayment, error) { var err error fakeInvoice := &Invoice{ // Use single second precision to avoid false positive test @@ -102,7 +124,10 @@ func makeRandomFakePayment() (*OutgoingPayment, error) { return nil, err } - fakeInvoice.PaymentRequest = []byte("") + fakeInvoice.PaymentRequest, err = randomBytes(1, 50) + if err != nil { + return nil, err + } preImg, err := randomBytes(32, 33) if err != nil { @@ -122,7 +147,7 @@ func makeRandomFakePayment() (*OutgoingPayment, error) { copy(fakePath[i][:], b) } - fakePayment := &OutgoingPayment{ + fakePayment := &outgoingPayment{ Invoice: *fakeInvoice, Fee: lnwire.MilliSatoshi(rand.Intn(1001)), Path: fakePath, @@ -133,147 +158,45 @@ func makeRandomFakePayment() (*OutgoingPayment, error) { return fakePayment, nil } -func TestOutgoingPaymentSerialization(t *testing.T) { +func TestSentPaymentSerialization(t *testing.T) { t.Parallel() - fakePayment := makeFakePayment() + c, s := makeFakeInfo() var b bytes.Buffer - if err := serializeOutgoingPayment(&b, fakePayment); err != nil { - t.Fatalf("unable to serialize outgoing payment: %v", err) + if err := serializePaymentCreationInfo(&b, c); err != nil { + t.Fatalf("unable to serialize creation info: %v", err) } - newPayment, err := deserializeOutgoingPayment(&b) + newCreationInfo, err := deserializePaymentCreationInfo(&b) if err != nil { - t.Fatalf("unable to deserialize outgoing payment: %v", err) + t.Fatalf("unable to deserialize creation info: %v", err) } - if !reflect.DeepEqual(fakePayment, newPayment) { + if !reflect.DeepEqual(c, newCreationInfo) { t.Fatalf("Payments do not match after "+ "serialization/deserialization %v vs %v", - spew.Sdump(fakePayment), - spew.Sdump(newPayment), - ) - } -} - -func TestOutgoingPaymentWorkflow(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - fakePayment := makeFakePayment() - if err = db.AddPayment(fakePayment); err != nil { - t.Fatalf("unable to put payment in DB: %v", err) - } - - payments, err := db.FetchAllPayments() - if err != nil { - t.Fatalf("unable to fetch payments from DB: %v", err) - } - - expectedPayments := []*OutgoingPayment{fakePayment} - if !reflect.DeepEqual(payments, expectedPayments) { - t.Fatalf("Wrong payments after reading from DB."+ - "Got %v, want %v", - spew.Sdump(payments), - spew.Sdump(expectedPayments), + spew.Sdump(c), spew.Sdump(newCreationInfo), ) } - // Make some random payments - for i := 0; i < 5; i++ { - randomPayment, err := makeRandomFakePayment() - if err != nil { - t.Fatalf("Internal error in tests: %v", err) - } - - if err = db.AddPayment(randomPayment); err != nil { - t.Fatalf("unable to put payment in DB: %v", err) - } - - expectedPayments = append(expectedPayments, randomPayment) + b.Reset() + if err := serializePaymentAttemptInfo(&b, s); err != nil { + t.Fatalf("unable to serialize info: %v", err) } - payments, err = db.FetchAllPayments() + newAttemptInfo, err := deserializePaymentAttemptInfo(&b) if err != nil { - t.Fatalf("Can't get payments from DB: %v", err) + t.Fatalf("unable to deserialize info: %v", err) } - if !reflect.DeepEqual(payments, expectedPayments) { - t.Fatalf("Wrong payments after reading from DB."+ - "Got %v, want %v", - spew.Sdump(payments), - spew.Sdump(expectedPayments), + if !reflect.DeepEqual(s, newAttemptInfo) { + t.Fatalf("Payments do not match after "+ + "serialization/deserialization %v vs %v", + spew.Sdump(s), spew.Sdump(newAttemptInfo), ) } - // Delete all payments. - if err = db.DeleteAllPayments(); err != nil { - t.Fatalf("unable to delete payments from DB: %v", err) - } - - // Check that there is no payments after deletion - paymentsAfterDeletion, err := db.FetchAllPayments() - if err != nil { - t.Fatalf("Can't get payments after deletion: %v", err) - } - if len(paymentsAfterDeletion) != 0 { - t.Fatalf("After deletion DB has %v payments, want %v", - len(paymentsAfterDeletion), 0) - } -} - -func TestPaymentStatusWorkflow(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - testCases := []struct { - paymentHash [32]byte - status PaymentStatus - }{ - { - paymentHash: makeFakePaymentHash(), - status: StatusGrounded, - }, - { - paymentHash: makeFakePaymentHash(), - status: StatusInFlight, - }, - { - paymentHash: makeFakePaymentHash(), - status: StatusCompleted, - }, - } - - for _, testCase := range testCases { - err := db.UpdatePaymentStatus(testCase.paymentHash, testCase.status) - if err != nil { - t.Fatalf("unable to put payment in DB: %v", err) - } - - status, err := db.FetchPaymentStatus(testCase.paymentHash) - if err != nil { - t.Fatalf("unable to fetch payments from DB: %v", err) - } - - if status != testCase.status { - t.Fatalf("Wrong payments status after reading from DB."+ - "Got %v, want %v", - spew.Sdump(status), - spew.Sdump(testCase.status), - ) - } - } } func TestRouteSerialization(t *testing.T) { diff --git a/htlcswitch/control_tower.go b/htlcswitch/control_tower.go deleted file mode 100644 index 380fb787..00000000 --- a/htlcswitch/control_tower.go +++ /dev/null @@ -1,245 +0,0 @@ -package htlcswitch - -import ( - "errors" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/lnwire" -) - -var ( - // ErrAlreadyPaid signals we have already paid this payment hash. - ErrAlreadyPaid = errors.New("invoice is already paid") - - // ErrPaymentInFlight signals that payment for this payment hash is - // already "in flight" on the network. - ErrPaymentInFlight = errors.New("payment is in transition") - - // ErrPaymentNotInitiated is returned if payment wasn't initiated in - // switch. - ErrPaymentNotInitiated = errors.New("payment isn't initiated") - - // ErrPaymentAlreadyCompleted is returned in the event we attempt to - // recomplete a completed payment. - ErrPaymentAlreadyCompleted = errors.New("payment is already completed") - - // ErrUnknownPaymentStatus is returned when we do not recognize the - // existing state of a payment. - ErrUnknownPaymentStatus = errors.New("unknown payment status") -) - -// ControlTower tracks all outgoing payments made by the switch, whose primary -// purpose is to prevent duplicate payments to the same payment hash. In -// production, a persistent implementation is preferred so that tracking can -// survive across restarts. Payments are transition through various payment -// states, and the ControlTower interface provides access to driving the state -// transitions. -type ControlTower interface { - // ClearForTakeoff atomically checks that no inflight or completed - // payments exist for this payment hash. If none are found, this method - // atomically transitions the status for this payment hash as InFlight. - ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error - - // Success transitions an InFlight payment into a Completed payment. - // After invoking this method, ClearForTakeoff should always return an - // error to prevent us from making duplicate payments to the same - // payment hash. - Success(paymentHash [32]byte) error - - // Fail transitions an InFlight payment into a Grounded Payment. After - // invoking this method, ClearForTakeoff should return nil on its next - // call for this payment hash, allowing the switch to make a subsequent - // payment. - Fail(paymentHash [32]byte) error -} - -// paymentControl is persistent implementation of ControlTower to restrict -// double payment sending. -type paymentControl struct { - strict bool - - db *channeldb.DB -} - -// NewPaymentControl creates a new instance of the paymentControl. The strict -// flag indicates whether the controller should require "strict" state -// transitions, which would be otherwise intolerant to older databases that may -// already have duplicate payments to the same payment hash. It should be -// enabled only after sufficient checks have been made to ensure the db does not -// contain such payments. In the meantime, non-strict mode enforces a superset -// of the state transitions that prevent additional payments to a given payment -// hash from being added. -func NewPaymentControl(strict bool, db *channeldb.DB) ControlTower { - return &paymentControl{ - strict: strict, - db: db, - } -} - -// ClearForTakeoff checks that we don't already have an InFlight or Completed -// payment identified by the same payment hash. -func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error { - var takeoffErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Retrieve current status of payment from local database. - paymentStatus, err := channeldb.FetchPaymentStatusTx( - tx, htlc.PaymentHash, - ) - if err != nil { - return err - } - - // Reset the takeoff error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - takeoffErr = nil - - switch paymentStatus { - - case channeldb.StatusGrounded: - // It is safe to reattempt a payment if we know that we - // haven't left one in flight. Since this one is - // grounded, Transition the payment status to InFlight - // to prevent others. - return channeldb.UpdatePaymentStatusTx( - tx, htlc.PaymentHash, channeldb.StatusInFlight, - ) - - case channeldb.StatusInFlight: - // We already have an InFlight payment on the network. We will - // disallow any more payment until a response is received. - takeoffErr = ErrPaymentInFlight - - case channeldb.StatusCompleted: - // We've already completed a payment to this payment hash, - // forbid the switch from sending another. - takeoffErr = ErrAlreadyPaid - - default: - takeoffErr = ErrUnknownPaymentStatus - } - - return nil - }) - if err != nil { - return err - } - - return takeoffErr -} - -// Success transitions an InFlight payment to Completed, otherwise it returns an -// error. After calling Success, ClearForTakeoff should prevent any further -// attempts for the same payment hash. -func (p *paymentControl) Success(paymentHash [32]byte) error { - var updateErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - paymentStatus, err := channeldb.FetchPaymentStatusTx( - tx, paymentHash, - ) - if err != nil { - return err - } - - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - switch { - - case paymentStatus == channeldb.StatusGrounded && p.strict: - // Our records show the payment as still being grounded, - // meaning it never should have left the switch. - updateErr = ErrPaymentNotInitiated - - case paymentStatus == channeldb.StatusGrounded && !p.strict: - // Though our records show the payment as still being - // grounded, meaning it never should have left the - // switch, we permit this transition in non-strict mode - // to handle inconsistent db states. - fallthrough - - case paymentStatus == channeldb.StatusInFlight: - // A successful response was received for an InFlight - // payment, mark it as completed to prevent sending to - // this payment hash again. - return channeldb.UpdatePaymentStatusTx( - tx, paymentHash, channeldb.StatusCompleted, - ) - - case paymentStatus == channeldb.StatusCompleted: - // The payment was completed previously, alert the - // caller that this may be a duplicate call. - updateErr = ErrPaymentAlreadyCompleted - - default: - updateErr = ErrUnknownPaymentStatus - } - - return nil - }) - if err != nil { - return err - } - - return updateErr -} - -// Fail transitions an InFlight payment to Grounded, otherwise it returns an -// error. After calling Fail, ClearForTakeoff should fail any further attempts -// for the same payment hash. -func (p *paymentControl) Fail(paymentHash [32]byte) error { - var updateErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - paymentStatus, err := channeldb.FetchPaymentStatusTx( - tx, paymentHash, - ) - if err != nil { - return err - } - - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - switch { - - case paymentStatus == channeldb.StatusGrounded && p.strict: - // Our records show the payment as still being grounded, - // meaning it never should have left the switch. - updateErr = ErrPaymentNotInitiated - - case paymentStatus == channeldb.StatusGrounded && !p.strict: - // Though our records show the payment as still being - // grounded, meaning it never should have left the - // switch, we permit this transition in non-strict mode - // to handle inconsistent db states. - fallthrough - - case paymentStatus == channeldb.StatusInFlight: - // A failed response was received for an InFlight - // payment, mark it as Grounded again to allow - // subsequent attempts. - return channeldb.UpdatePaymentStatusTx( - tx, paymentHash, channeldb.StatusGrounded, - ) - - case paymentStatus == channeldb.StatusCompleted: - // The payment was completed previously, and we are now - // reporting that it has failed. Leave the status as - // completed, but alert the user that something is - // wrong. - updateErr = ErrPaymentAlreadyCompleted - - default: - updateErr = ErrUnknownPaymentStatus - } - - return nil - }) - if err != nil { - return err - } - - return updateErr -} diff --git a/htlcswitch/control_tower_test.go b/htlcswitch/control_tower_test.go deleted file mode 100644 index 2728e362..00000000 --- a/htlcswitch/control_tower_test.go +++ /dev/null @@ -1,351 +0,0 @@ -package htlcswitch - -import ( - "fmt" - "testing" - - "github.com/btcsuite/fastsha256" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/lnwire" -) - -func genHtlc() (*lnwire.UpdateAddHTLC, error) { - preimage, err := genPreimage() - if err != nil { - return nil, fmt.Errorf("unable to generate preimage: %v", err) - } - - rhash := fastsha256.Sum256(preimage[:]) - htlc := &lnwire.UpdateAddHTLC{ - PaymentHash: rhash, - Amount: 1, - } - - return htlc, nil -} - -type paymentControlTestCase func(*testing.T, bool) - -var paymentControlTests = []struct { - name string - strict bool - testcase paymentControlTestCase -}{ - { - name: "fail-strict", - strict: true, - testcase: testPaymentControlSwitchFail, - }, - { - name: "double-send-strict", - strict: true, - testcase: testPaymentControlSwitchDoubleSend, - }, - { - name: "double-pay-strict", - strict: true, - testcase: testPaymentControlSwitchDoublePay, - }, - { - name: "fail-not-strict", - strict: false, - testcase: testPaymentControlSwitchFail, - }, - { - name: "double-send-not-strict", - strict: false, - testcase: testPaymentControlSwitchDoubleSend, - }, - { - name: "double-pay-not-strict", - strict: false, - testcase: testPaymentControlSwitchDoublePay, - }, -} - -// TestPaymentControls runs a set of common tests against both the strict and -// non-strict payment control instances. This ensures that the two both behave -// identically when making the expected state-transitions of the stricter -// implementation. Behavioral differences in the strict and non-strict -// implementations are tested separately. -func TestPaymentControls(t *testing.T) { - for _, test := range paymentControlTests { - t.Run(test.name, func(t *testing.T) { - test.testcase(t, test.strict) - }) - } -} - -// testPaymentControlSwitchFail checks that payment status returns to Grounded -// status after failing, and that ClearForTakeoff allows another HTLC for the -// same payment hash. -func testPaymentControlSwitchFail(t *testing.T, strict bool) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(strict, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - if err := pControl.ClearForTakeoff(htlc); err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) - - // Fail the payment, which should moved it to Grounded. - if err := pControl.Fail(htlc.PaymentHash); err != nil { - t.Fatalf("unable to fail payment hash: %v", err) - } - - // Verify the status is indeed Grounded. - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) - - // Sends the htlc again, which should succeed since the prior payment - // failed. - if err := pControl.ClearForTakeoff(htlc); err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) - - // Verifies that status was changed to StatusCompleted. - if err := pControl.Success(htlc.PaymentHash); err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) - - // Attempt a final payment, which should now fail since the prior - // payment succeed. - if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid { - t.Fatalf("unable to send htlc message: %v", err) - } -} - -// testPaymentControlSwitchDoubleSend checks the ability of payment control to -// prevent double sending of htlc message, when message is in StatusInFlight. -func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(strict, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate base status and move it to - // StatusInFlight and verifies that it was changed. - if err := pControl.ClearForTakeoff(htlc); err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) - - // Try to initiate double sending of htlc message with the same - // payment hash, should result in error indicating that payment has - // already been sent. - if err := pControl.ClearForTakeoff(htlc); err != ErrPaymentInFlight { - t.Fatalf("payment control wrong behaviour: " + - "double sending must trigger ErrPaymentInFlight error") - } -} - -// TestPaymentControlSwitchDoublePay checks the ability of payment control to -// prevent double payment. -func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(strict, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - if err := pControl.ClearForTakeoff(htlc); err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - // Verify that payment is InFlight. - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) - - // Move payment to completed status, second payment should return error. - if err := pControl.Success(htlc.PaymentHash); err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - - // Verify that payment is Completed. - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) - - if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid { - t.Fatalf("payment control wrong behaviour:" + - " double payment must trigger ErrAlreadyPaid") - } -} - -// TestPaymentControlNonStrictSuccessesWithoutInFlight checks that a non-strict -// payment control will allow calls to Success when no payment is in flight. This -// is necessary to gracefully handle the case in which the switch already sent -// out a payment for a particular payment hash in a prior db version that didn't -// have payment statuses. -func TestPaymentControlNonStrictSuccessesWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(false, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - if err := pControl.Success(htlc.PaymentHash); err != nil { - t.Fatalf("unable to mark payment hash success: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) - - err = pControl.Success(htlc.PaymentHash) - if err != ErrPaymentAlreadyCompleted { - t.Fatalf("unable to remark payment hash failed: %v", err) - } -} - -// TestPaymentControlNonStrictFailsWithoutInFlight checks that a non-strict -// payment control will allow calls to Fail when no payment is in flight. This -// is necessary to gracefully handle the case in which the switch already sent -// out a payment for a particular payment hash in a prior db version that didn't -// have payment statuses. -func TestPaymentControlNonStrictFailsWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(false, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - if err := pControl.Fail(htlc.PaymentHash); err != nil { - t.Fatalf("unable to mark payment hash failed: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) - - err = pControl.Fail(htlc.PaymentHash) - if err != nil { - t.Fatalf("unable to remark payment hash failed: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) - - err = pControl.Success(htlc.PaymentHash) - if err != nil { - t.Fatalf("unable to remark payment hash success: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) - - err = pControl.Fail(htlc.PaymentHash) - if err != ErrPaymentAlreadyCompleted { - t.Fatalf("unable to remark payment hash failed: %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) -} - -// TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment -// control will disallow calls to Success when no payment is in flight. -func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(true, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - err = pControl.Success(htlc.PaymentHash) - if err != ErrPaymentNotInitiated { - t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) -} - -// TestPaymentControlStrictFailsWithoutInFlight checks that a strict payment -// control will disallow calls to Fail when no payment is in flight. -func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(true, db) - - htlc, err := genHtlc() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - err = pControl.Fail(htlc.PaymentHash) - if err != ErrPaymentNotInitiated { - t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) - } - - assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) -} - -func assertPaymentStatus(t *testing.T, db *channeldb.DB, - hash [32]byte, expStatus channeldb.PaymentStatus) { - - t.Helper() - - pStatus, err := db.FetchPaymentStatus(hash) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - if pStatus != expStatus { - t.Fatalf("payment status mismatch: expected %v, got %v", - expStatus, pStatus) - } -} diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 55184e40..646c6450 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -3889,8 +3889,7 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { } // With the invoice now added to Carol's registry, we'll send the - // payment. It should succeed w/o any issues as it has been crafted - // properly. + // payment. err = n.aliceServer.htlcSwitch.SendHTLC( n.firstBobChannelLink.ShortChanID(), pid, htlc, ) @@ -3905,6 +3904,16 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { t.Fatalf("unable to get payment result: %v", err) } + // Now, if we attempt to send the payment *again* it should be rejected + // as it's a duplicate request. + err = n.aliceServer.htlcSwitch.SendHTLC( + n.firstBobChannelLink.ShortChanID(), pid, htlc, + ) + if err != ErrPaymentIDAlreadyExists { + t.Fatalf("ErrPaymentIDAlreadyExists should have been "+ + "received got: %v", err) + } + select { case result, ok := <-resultChan: if !ok { @@ -3917,15 +3926,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) { case <-time.After(5 * time.Second): t.Fatalf("payment result did not arrive") } - - // Now, if we attempt to send the payment *again* it should be rejected - // as it's a duplicate request. - err = n.aliceServer.htlcSwitch.SendHTLC( - n.firstBobChannelLink.ShortChanID(), pid, htlc, - ) - if err != ErrAlreadyPaid { - t.Fatalf("ErrAlreadyPaid should have been received got: %v", err) - } } // TestChannelLinkAcceptOverpay tests that if we create an invoice for sender, diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index c850783e..45e2e2db 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -208,9 +208,6 @@ type Switch struct { pendingPayments map[uint64]*pendingPayment pendingMutex sync.RWMutex - // control provides verification of sending htlc mesages - control ControlTower - // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. circuits CircuitMap @@ -290,7 +287,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { bestHeight: currentHeight, cfg: &cfg, circuits: circuitMap, - control: NewPaymentControl(false, cfg.DB), linkIndex: make(map[lnwire.ChannelID]ChannelLink), mailOrchestrator: newMailOrchestrator(), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), @@ -402,13 +398,6 @@ func (s *Switch) GetPaymentResult(paymentID uint64, func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, htlc *lnwire.UpdateAddHTLC) error { - // Before sending, double check that we don't already have 1) an - // in-flight payment to this payment hash, or 2) a complete payment for - // the same hash. - if err := s.control.ClearForTakeoff(htlc); err != nil { - return err - } - // Create payment and add to the map of payment in order later to be // able to retrieve it and return response to the user. payment := &pendingPayment{ @@ -439,10 +428,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, if err := s.forward(packet); err != nil { s.removePendingPayment(paymentID) - if err := s.control.Fail(htlc.PaymentHash); err != nil { - return err - } - return err } @@ -939,15 +924,6 @@ func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult, // We've received a settle update which means we can finalize the user // payment and return successful response. case *lnwire.UpdateFulfillHTLC: - // Persistently mark that a payment to this payment hash - // succeeded. This will prevent us from ever making another - // payment to this hash. - err := s.control.Success(paymentHash) - if err != nil && err != ErrPaymentAlreadyCompleted { - return nil, fmt.Errorf("Unable to mark completed "+ - "payment %x: %v", paymentHash, err) - } - return &PaymentResult{ Preimage: htlc.PaymentPreimage, }, nil @@ -955,14 +931,6 @@ func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult, // We've received a fail update which means we can finalize the // user payment and return fail response. case *lnwire.UpdateFailHTLC: - // Persistently mark that a payment to this payment hash - // failed. This will permit us to make another attempt at a - // successful payment. - err := s.control.Fail(paymentHash) - if err != nil && err != ErrPaymentAlreadyCompleted { - return nil, fmt.Errorf("Unable to ground payment "+ - "%x: %v", paymentHash, err) - } paymentErr := s.parseFailedPayment( deobfuscator, paymentID, paymentHash, n.unencrypted, n.isResolution, htlc, diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index 0a8a48a1..9cbaa911 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -14,7 +14,7 @@ import ( const ( // vertexDecay is the decay period of colored vertexes added to - // missionControl. Once vertexDecay passes after an entry has been + // MissionControl. Once vertexDecay passes after an entry has been // added to the prune view, it is garbage collected. This value is // larger than edgeDecay as an edge failure typical indicates an // unbalanced channel, while a vertex failure indicates a node is not @@ -22,7 +22,7 @@ const ( vertexDecay = time.Duration(time.Minute * 5) // edgeDecay is the decay period of colored edges added to - // missionControl. Once edgeDecay passed after an entry has been added, + // MissionControl. Once edgeDecay passed after an entry has been added, // it is garbage collected. This value is smaller than vertexDecay as // an edge related failure during payment sending typically indicates // that a channel was unbalanced, a condition which may quickly change. @@ -31,11 +31,11 @@ const ( edgeDecay = time.Duration(time.Second * 5) ) -// missionControl contains state which summarizes the past attempts of HTLC +// MissionControl contains state which summarizes the past attempts of HTLC // routing by external callers when sending payments throughout the network. -// missionControl remembers the outcome of these past routing attempts (success +// MissionControl remembers the outcome of these past routing attempts (success // and failure), and is able to provide hints/guidance to future HTLC routing -// attempts. missionControl maintains a decaying network view of the +// attempts. MissionControl maintains a decaying network view of the // edges/vertexes that should be marked as "pruned" during path finding. This // graph view acts as a shared memory during HTLC payment routing attempts. // With each execution, if an error is encountered, based on the type of error @@ -43,16 +43,16 @@ const ( // to the view. Later sending attempts will then query the view for all the // vertexes/edges that should be ignored. Items in the view decay after a set // period of time, allowing the view to be dynamic w.r.t network changes. -type missionControl struct { +type MissionControl struct { // failedEdges maps a short channel ID to be pruned, to the time that // it was added to the prune view. Edges are added to this map if a - // caller reports to missionControl a failure localized to that edge + // caller reports to MissionControl a failure localized to that edge // when sending a payment. failedEdges map[EdgeLocator]time.Time // failedVertexes maps a node's public key that should be pruned, to // the time that it was added to the prune view. Vertexes are added to - // this map if a caller reports to missionControl a failure localized + // this map if a caller reports to MissionControl a failure localized // to that particular vertex. failedVertexes map[route.Vertex]time.Time @@ -70,13 +70,17 @@ type missionControl struct { // TODO(roasbeef): also add favorable metrics for nodes } -// newMissionControl returns a new instance of missionControl. +// A compile time assertion to ensure MissionControl meets the +// PaymentSessionSource interface. +var _ PaymentSessionSource = (*MissionControl)(nil) + +// NewMissionControl returns a new instance of MissionControl. // // TODO(roasbeef): persist memory -func newMissionControl(g *channeldb.ChannelGraph, selfNode *channeldb.LightningNode, - qb func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) *missionControl { +func NewMissionControl(g *channeldb.ChannelGraph, selfNode *channeldb.LightningNode, + qb func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) *MissionControl { - return &missionControl{ + return &MissionControl{ failedEdges: make(map[EdgeLocator]time.Time), failedVertexes: make(map[route.Vertex]time.Time), selfNode: selfNode, @@ -96,12 +100,12 @@ type graphPruneView struct { vertexes map[route.Vertex]struct{} } -// GraphPruneView returns a new graphPruneView instance which is to be +// graphPruneView returns a new graphPruneView instance which is to be // consulted during path finding. If a vertex/edge is found within the returned // prune view, it is to be ignored as a goroutine has had issues routing // through it successfully. Within this method the main view of the -// missionControl is garbage collected as entries are detected to be "stale". -func (m *missionControl) GraphPruneView() graphPruneView { +// MissionControl is garbage collected as entries are detected to be "stale". +func (m *MissionControl) graphPruneView() graphPruneView { // First, we'll grab the current time, this value will be used to // determine if an entry is stale or not. now := time.Now() @@ -154,10 +158,10 @@ func (m *missionControl) GraphPruneView() graphPruneView { // view from Mission Control. An optional set of routing hints can be provided // in order to populate additional edges to explore when finding a path to the // payment's destination. -func (m *missionControl) NewPaymentSession(routeHints [][]zpay32.HopHint, - target route.Vertex) (*paymentSession, error) { +func (m *MissionControl) NewPaymentSession(routeHints [][]zpay32.HopHint, + target route.Vertex) (PaymentSession, error) { - viewSnapshot := m.GraphPruneView() + viewSnapshot := m.graphPruneView() edges := make(map[route.Vertex][]*channeldb.ChannelEdgePolicy) @@ -233,15 +237,28 @@ func (m *missionControl) NewPaymentSession(routeHints [][]zpay32.HopHint, // NewPaymentSessionForRoute creates a new paymentSession instance that is just // used for failure reporting to missioncontrol. -func (m *missionControl) NewPaymentSessionForRoute(preBuiltRoute *route.Route) *paymentSession { +func (m *MissionControl) NewPaymentSessionForRoute(preBuiltRoute *route.Route) PaymentSession { return &paymentSession{ - pruneViewSnapshot: m.GraphPruneView(), + pruneViewSnapshot: m.graphPruneView(), errFailedPolicyChans: make(map[EdgeLocator]struct{}), mc: m, preBuiltRoute: preBuiltRoute, } } +// NewPaymentSessionEmpty creates a new paymentSession instance that is empty, +// and will be exhausted immediately. Used for failure reporting to +// missioncontrol for resumed payment we don't want to make more attempts for. +func (m *MissionControl) NewPaymentSessionEmpty() PaymentSession { + return &paymentSession{ + pruneViewSnapshot: m.graphPruneView(), + errFailedPolicyChans: make(map[EdgeLocator]struct{}), + mc: m, + preBuiltRoute: &route.Route{}, + preBuiltRouteTried: true, + } +} + // generateBandwidthHints is a helper function that's utilized the main // findPath function in order to obtain hints from the lower layer w.r.t to the // available bandwidth of edges on the network. Currently, we'll only obtain @@ -277,9 +294,9 @@ func generateBandwidthHints(sourceNode *channeldb.LightningNode, return bandwidthHints, nil } -// ResetHistory resets the history of missionControl returning it to a state as +// ResetHistory resets the history of MissionControl returning it to a state as // if no payment attempts have been made. -func (m *missionControl) ResetHistory() { +func (m *MissionControl) ResetHistory() { m.Lock() m.failedEdges = make(map[EdgeLocator]time.Time) m.failedVertexes = make(map[route.Vertex]time.Time) diff --git a/routing/mock_test.go b/routing/mock_test.go index 03aa2923..6d98dd6e 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -1,8 +1,15 @@ package routing import ( + "fmt" + "sync" + + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/zpay32" ) type mockPaymentAttemptDispatcher struct { @@ -62,3 +69,219 @@ func (m *mockPaymentAttemptDispatcher) setPaymentResult( m.onPayment = f } + +type mockPaymentSessionSource struct { + routes []*route.Route +} + +var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil) + +func (m *mockPaymentSessionSource) NewPaymentSession(routeHints [][]zpay32.HopHint, + target route.Vertex) (PaymentSession, error) { + + return &mockPaymentSession{m.routes}, nil +} + +func (m *mockPaymentSessionSource) NewPaymentSessionForRoute( + preBuiltRoute *route.Route) PaymentSession { + return nil +} + +func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession { + return &mockPaymentSession{} +} + +type mockPaymentSession struct { + routes []*route.Route +} + +var _ PaymentSession = (*mockPaymentSession)(nil) + +func (m *mockPaymentSession) RequestRoute(payment *LightningPayment, + height uint32, finalCltvDelta uint16) (*route.Route, error) { + + if len(m.routes) == 0 { + return nil, fmt.Errorf("no routes") + } + + r := m.routes[0] + m.routes = m.routes[1:] + + return r, nil +} + +func (m *mockPaymentSession) ReportVertexFailure(v route.Vertex) {} + +func (m *mockPaymentSession) ReportEdgeFailure(e *EdgeLocator) {} + +func (m *mockPaymentSession) ReportEdgePolicyFailure(errSource route.Vertex, failedEdge *EdgeLocator) { +} + +type mockPayer struct { + sendResult chan error + paymentResultErr chan error + paymentResult chan *htlcswitch.PaymentResult + quit chan struct{} +} + +var _ PaymentAttemptDispatcher = (*mockPayer)(nil) + +func (m *mockPayer) SendHTLC(_ lnwire.ShortChannelID, + paymentID uint64, + _ *lnwire.UpdateAddHTLC) error { + + select { + case res := <-m.sendResult: + return res + case <-m.quit: + return fmt.Errorf("test quitting") + } + +} + +func (m *mockPayer) GetPaymentResult(paymentID uint64, _ htlcswitch.ErrorDecrypter) ( + <-chan *htlcswitch.PaymentResult, error) { + + select { + case res := <-m.paymentResult: + resChan := make(chan *htlcswitch.PaymentResult, 1) + resChan <- res + return resChan, nil + case err := <-m.paymentResultErr: + return nil, err + case <-m.quit: + return nil, fmt.Errorf("test quitting") + } +} + +type initArgs struct { + c *channeldb.PaymentCreationInfo +} + +type registerArgs struct { + a *channeldb.PaymentAttemptInfo +} + +type successArgs struct { + preimg lntypes.Preimage +} + +type failArgs struct { + reason channeldb.FailureReason +} + +type mockControlTower struct { + inflights map[lntypes.Hash]channeldb.InFlightPayment + successful map[lntypes.Hash]struct{} + + init chan initArgs + register chan registerArgs + success chan successArgs + fail chan failArgs + fetchInFlight chan struct{} + + sync.Mutex +} + +var _ channeldb.ControlTower = (*mockControlTower)(nil) + +func makeMockControlTower() *mockControlTower { + return &mockControlTower{ + inflights: make(map[lntypes.Hash]channeldb.InFlightPayment), + successful: make(map[lntypes.Hash]struct{}), + } +} + +func (m *mockControlTower) InitPayment(phash lntypes.Hash, + c *channeldb.PaymentCreationInfo) error { + + m.Lock() + defer m.Unlock() + + if m.init != nil { + m.init <- initArgs{c} + } + + if _, ok := m.successful[phash]; ok { + return fmt.Errorf("already successful") + } + + _, ok := m.inflights[phash] + if ok { + return fmt.Errorf("in flight") + } + + m.inflights[phash] = channeldb.InFlightPayment{ + Info: c, + } + + return nil +} + +func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash, + a *channeldb.PaymentAttemptInfo) error { + + m.Lock() + defer m.Unlock() + + if m.register != nil { + m.register <- registerArgs{a} + } + + p, ok := m.inflights[phash] + if !ok { + return fmt.Errorf("not in flight") + } + + p.Attempt = a + m.inflights[phash] = p + + return nil +} + +func (m *mockControlTower) Success(phash lntypes.Hash, + preimg lntypes.Preimage) error { + + m.Lock() + defer m.Unlock() + + if m.success != nil { + m.success <- successArgs{preimg} + } + + delete(m.inflights, phash) + m.successful[phash] = struct{}{} + return nil +} + +func (m *mockControlTower) Fail(phash lntypes.Hash, + reason channeldb.FailureReason) error { + + m.Lock() + defer m.Unlock() + + if m.fail != nil { + m.fail <- failArgs{reason} + } + + delete(m.inflights, phash) + return nil +} + +func (m *mockControlTower) FetchInFlightPayments() ( + []*channeldb.InFlightPayment, error) { + + m.Lock() + defer m.Unlock() + + if m.fetchInFlight != nil { + m.fetchInFlight <- struct{}{} + } + + var fl []*channeldb.InFlightPayment + for _, ifl := range m.inflights { + fl = append(fl, &ifl) + } + + return fl, nil +} diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go new file mode 100644 index 00000000..38883ca8 --- /dev/null +++ b/routing/payment_lifecycle.go @@ -0,0 +1,354 @@ +package routing + +import ( + "fmt" + "time" + + "github.com/davecgh/go-spew/spew" + sphinx "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// paymentLifecycle holds all information about the current state of a payment +// needed to resume if from any point. +type paymentLifecycle struct { + router *ChannelRouter + payment *LightningPayment + paySession PaymentSession + timeoutChan <-chan time.Time + currentHeight int32 + finalCLTVDelta uint16 + attempt *channeldb.PaymentAttemptInfo + circuit *sphinx.Circuit + lastError error +} + +// resumePayment resumes the paymentLifecycle from the current state. +func (p *paymentLifecycle) resumePayment() ([32]byte, *route.Route, error) { + // We'll continue until either our payment succeeds, or we encounter a + // critical error during path finding. + for { + + // If this payment had no existing payment attempt, we create + // and send one now. + if p.attempt == nil { + firstHop, htlcAdd, err := p.createNewPaymentAttempt() + if err != nil { + return [32]byte{}, nil, err + } + + // Now that the attempt is created and checkpointed to + // the DB, we send it. + sendErr := p.sendPaymentAttempt(firstHop, htlcAdd) + if sendErr != nil { + // We must inspect the error to know whether it + // was critical or not, to decide whether we + // should continue trying. + err := p.handleSendError(sendErr) + if err != nil { + return [32]byte{}, nil, err + } + + // Error was handled successfully, reset the + // attempt to indicate we want to make a new + // attempt. + p.attempt = nil + continue + } + } else { + // If this was a resumed attempt, we must regenerate the + // circuit. + _, c, err := generateSphinxPacket( + &p.attempt.Route, p.payment.PaymentHash[:], + p.attempt.SessionKey, + ) + if err != nil { + return [32]byte{}, nil, err + } + p.circuit = c + } + + // Using the created circuit, initialize the error decrypter so we can + // parse+decode any failures incurred by this payment within the + // switch. + errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ + OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(p.circuit), + } + + // Now ask the switch to return the result of the payment when + // available. + resultChan, err := p.router.cfg.Payer.GetPaymentResult( + p.attempt.PaymentID, errorDecryptor, + ) + switch { + + // If this payment ID is unknown to the Switch, it means it was + // never checkpointed and forwarded by the switch before a + // restart. In this case we can safely send a new payment + // attempt, and wait for its result to be available. + case err == htlcswitch.ErrPaymentIDNotFound: + log.Debugf("Payment ID %v for hash %x not found in "+ + "the Switch, retrying.", p.attempt.PaymentID, + p.payment.PaymentHash) + + // Reset the attempt to indicate we want to make a new + // attempt. + p.attempt = nil + continue + + // A critical, unexpected error was encountered. + case err != nil: + log.Errorf("Failed getting result for paymentID %d "+ + "from switch: %v", p.attempt.PaymentID, err) + + return [32]byte{}, nil, err + } + + // The switch knows about this payment, we'll wait for a result + // to be available. + var ( + result *htlcswitch.PaymentResult + ok bool + ) + + select { + case result, ok = <-resultChan: + if !ok { + return [32]byte{}, nil, htlcswitch.ErrSwitchExiting + } + + case <-p.router.quit: + return [32]byte{}, nil, ErrRouterShuttingDown + } + + // In case of a payment failure, we use the error to decide + // whether we should retry. + if result.Error != nil { + log.Errorf("Attempt to send payment %x failed: %v", + p.payment.PaymentHash, result.Error) + + // We must inspect the error to know whether it was + // critical or not, to decide whether we should + // continue trying. + if err := p.handleSendError(result.Error); err != nil { + return [32]byte{}, nil, err + } + + // Error was handled successfully, reset the attempt to + // indicate we want to make a new attempt. + p.attempt = nil + continue + } + + // We successfully got a payment result back from the switch. + log.Debugf("Payment %x succeeded with pid=%v", + p.payment.PaymentHash, p.attempt.PaymentID) + + // In case of success we atomically store the db payment and + // move the payment to the success state. + err = p.router.cfg.Control.Success(p.payment.PaymentHash, result.Preimage) + if err != nil { + log.Errorf("Unable to succeed payment "+ + "attempt: %v", err) + return [32]byte{}, nil, err + } + + // Terminal state, return the preimage and the route + // taken. + return result.Preimage, &p.attempt.Route, nil + } + +} + +// createNewPaymentAttempt creates and stores a new payment attempt to the +// database. +func (p *paymentLifecycle) createNewPaymentAttempt() (lnwire.ShortChannelID, + *lnwire.UpdateAddHTLC, error) { + + // Before we attempt this next payment, we'll check to see if + // either we've gone past the payment attempt timeout, or the + // router is exiting. In either case, we'll stop this payment + // attempt short. + select { + case <-p.timeoutChan: + // Mark the payment as failed because of the + // timeout. + err := p.router.cfg.Control.Fail( + p.payment.PaymentHash, channeldb.FailureReasonTimeout, + ) + if err != nil { + return lnwire.ShortChannelID{}, nil, err + } + + errStr := fmt.Sprintf("payment attempt not completed " + + "before timeout") + + return lnwire.ShortChannelID{}, nil, + newErr(ErrPaymentAttemptTimeout, errStr) + + case <-p.router.quit: + // The payment will be resumed from the current state + // after restart. + return lnwire.ShortChannelID{}, nil, ErrRouterShuttingDown + + default: + // Fall through if we haven't hit our time limit, or + // are expiring. + } + + // Create a new payment attempt from the given payment session. + route, err := p.paySession.RequestRoute( + p.payment, uint32(p.currentHeight), p.finalCLTVDelta, + ) + if err != nil { + // If we're unable to successfully make a payment using + // any of the routes we've found, then mark the payment + // as permanently failed. + saveErr := p.router.cfg.Control.Fail( + p.payment.PaymentHash, channeldb.FailureReasonNoRoute, + ) + if saveErr != nil { + return lnwire.ShortChannelID{}, nil, saveErr + } + + // If there was an error already recorded for this + // payment, we'll return that. + if p.lastError != nil { + return lnwire.ShortChannelID{}, nil, + fmt.Errorf("unable to route payment to "+ + "destination: %v", p.lastError) + } + + // Terminal state, return. + return lnwire.ShortChannelID{}, nil, err + } + + // Generate a new key to be used for this attempt. + sessionKey, err := generateNewSessionKey() + if err != nil { + return lnwire.ShortChannelID{}, nil, err + } + + // Generate the raw encoded sphinx packet to be included along + // with the htlcAdd message that we send directly to the + // switch. + onionBlob, c, err := generateSphinxPacket( + route, p.payment.PaymentHash[:], sessionKey, + ) + if err != nil { + return lnwire.ShortChannelID{}, nil, err + } + + // Update our cached circuit with the newly generated + // one. + p.circuit = c + + // Craft an HTLC packet to send to the layer 2 switch. The + // metadata within this packet will be used to route the + // payment through the network, starting with the first-hop. + htlcAdd := &lnwire.UpdateAddHTLC{ + Amount: route.TotalAmount, + Expiry: route.TotalTimeLock, + PaymentHash: p.payment.PaymentHash, + } + copy(htlcAdd.OnionBlob[:], onionBlob) + + // Attempt to send this payment through the network to complete + // the payment. If this attempt fails, then we'll continue on + // to the next available route. + firstHop := lnwire.NewShortChanIDFromInt( + route.Hops[0].ChannelID, + ) + + // We generate a new, unique payment ID that we will use for + // this HTLC. + paymentID, err := p.router.cfg.NextPaymentID() + if err != nil { + return lnwire.ShortChannelID{}, nil, err + } + + // We now have all the information needed to populate + // the current attempt information. + p.attempt = &channeldb.PaymentAttemptInfo{ + PaymentID: paymentID, + SessionKey: sessionKey, + Route: *route, + } + + // Before sending this HTLC to the switch, we checkpoint the + // fresh paymentID and route to the DB. This lets us know on + // startup the ID of the payment that we attempted to send, + // such that we can query the Switch for its whereabouts. The + // route is needed to handle the result when it eventually + // comes back. + err = p.router.cfg.Control.RegisterAttempt(p.payment.PaymentHash, p.attempt) + if err != nil { + return lnwire.ShortChannelID{}, nil, err + } + + return firstHop, htlcAdd, nil +} + +// sendPaymentAttempt attempts to send the current attempt to the switch. +func (p *paymentLifecycle) sendPaymentAttempt(firstHop lnwire.ShortChannelID, + htlcAdd *lnwire.UpdateAddHTLC) error { + + log.Tracef("Attempting to send payment %x (pid=%v), "+ + "using route: %v", p.payment.PaymentHash, p.attempt.PaymentID, + newLogClosure(func() string { + return spew.Sdump(p.attempt.Route) + }), + ) + + // Send it to the Switch. When this method returns we assume + // the Switch successfully has persisted the payment attempt, + // such that we can resume waiting for the result after a + // restart. + err := p.router.cfg.Payer.SendHTLC( + firstHop, p.attempt.PaymentID, htlcAdd, + ) + if err != nil { + log.Errorf("Failed sending attempt %d for payment "+ + "%x to switch: %v", p.attempt.PaymentID, + p.payment.PaymentHash, err) + return err + } + + log.Debugf("Payment %x (pid=%v) successfully sent to switch", + p.payment.PaymentHash, p.attempt.PaymentID) + + return nil +} + +// handleSendError inspects the given error from the Switch and determines +// whether we should make another payment attempt. +func (p *paymentLifecycle) handleSendError(sendErr error) error { + finalOutcome := p.router.processSendError( + p.paySession, &p.attempt.Route, sendErr, + ) + + if finalOutcome { + log.Errorf("Payment %x failed with final outcome: %v", + p.payment.PaymentHash, sendErr) + + // Mark the payment failed with no route. + // TODO(halseth): make payment codes for the actual reason we + // don't continue path finding. + err := p.router.cfg.Control.Fail( + p.payment.PaymentHash, channeldb.FailureReasonNoRoute, + ) + if err != nil { + return err + } + + // Terminal state, return the error we encountered. + return sendErr + } + + // We get ready to make another payment attempt. + p.lastError = sendErr + return nil +} diff --git a/routing/payment_session.go b/routing/payment_session.go index 459a8de2..28047f53 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -9,9 +9,39 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) +// PaymentSession is used during SendPayment attempts to provide routes to +// attempt. It also defines methods to give the PaymentSession additional +// information learned during the previous attempts. +type PaymentSession interface { + // RequestRoute returns the next route to attempt for routing the + // specified HTLC payment to the target node. + RequestRoute(payment *LightningPayment, + height uint32, finalCltvDelta uint16) (*route.Route, error) + + // ReportVertexFailure reports to the PaymentSession that the passsed + // vertex failed to route the previous payment attempt. The + // PaymentSession will use this information to produce a better next + // route. + ReportVertexFailure(v route.Vertex) + + // ReportEdgeFailure reports to the PaymentSession that the passed + // channel failed to route the previous payment attempt. The + // PaymentSession will use this information to produce a better next + // route. + ReportEdgeFailure(e *EdgeLocator) + + // ReportEdgePolicyFailure reports to the PaymentSession that we + // received a failure message that relates to a channel policy. For + // these types of failures, the PaymentSession can decide whether to to + // keep the edge included in the next attempted route. The + // PaymentSession will use this information to produce a better next + // route. + ReportEdgePolicyFailure(errSource route.Vertex, failedEdge *EdgeLocator) +} + // paymentSession is used during an HTLC routings session to prune the local // chain view in response to failures, and also report those failures back to -// missionControl. The snapshot copied for this session will only ever grow, +// MissionControl. The snapshot copied for this session will only ever grow, // and will now be pruned after a decay like the main view within mission // control. We do this as we want to avoid the case where we continually try a // bad edge or route multiple times in a session. This can lead to an infinite @@ -30,7 +60,7 @@ type paymentSession struct { // require pruning, but any subsequent ones do. errFailedPolicyChans map[EdgeLocator]struct{} - mc *missionControl + mc *MissionControl preBuiltRoute *route.Route preBuiltRouteTried bool @@ -38,11 +68,17 @@ type paymentSession struct { pathFinder pathFinder } +// A compile time assertion to ensure paymentSession meets the PaymentSession +// interface. +var _ PaymentSession = (*paymentSession)(nil) + // ReportVertexFailure adds a vertex to the graph prune view after a client // reports a routing failure localized to the vertex. The time the vertex was // added is noted, as it'll be pruned from the shared view after a period of // vertexDecay. However, the vertex will remain pruned for the *local* session. // This ensures we don't retry this vertex during the payment attempt. +// +// NOTE: Part of the PaymentSession interface. func (p *paymentSession) ReportVertexFailure(v route.Vertex) { log.Debugf("Reporting vertex %v failure to Mission Control", v) @@ -57,13 +93,15 @@ func (p *paymentSession) ReportVertexFailure(v route.Vertex) { p.mc.Unlock() } -// ReportChannelFailure adds a channel to the graph prune view. The time the +// ReportEdgeFailure adds a channel to the graph prune view. The time the // channel was added is noted, as it'll be pruned from the global view after a // period of edgeDecay. However, the edge will remain pruned for the duration // of the *local* session. This ensures that we don't flap by continually // retrying an edge after its pruning has expired. // // TODO(roasbeef): also add value attempted to send and capacity of channel +// +// NOTE: Part of the PaymentSession interface. func (p *paymentSession) ReportEdgeFailure(e *EdgeLocator) { log.Debugf("Reporting edge %v failure to Mission Control", e) @@ -78,12 +116,14 @@ func (p *paymentSession) ReportEdgeFailure(e *EdgeLocator) { p.mc.Unlock() } -// ReportChannelPolicyFailure handles a failure message that relates to a +// ReportEdgePolicyFailure handles a failure message that relates to a // channel policy. For these types of failures, the policy is updated and we // want to keep it included during path finding. This function does mark the // edge as 'policy failed once'. The next time it fails, the whole node will be // pruned. This is to prevent nodes from keeping us busy by continuously sending // new channel updates. +// +// NOTE: Part of the PaymentSession interface. func (p *paymentSession) ReportEdgePolicyFailure( errSource route.Vertex, failedEdge *EdgeLocator) { @@ -111,6 +151,7 @@ func (p *paymentSession) ReportEdgePolicyFailure( // will be explored, which feeds into the recommendations made for routing. // // NOTE: This function is safe for concurrent access. +// NOTE: Part of the PaymentSession interface. func (p *paymentSession) RequestRoute(payment *LightningPayment, height uint32, finalCltvDelta uint16) (*route.Route, error) { @@ -151,7 +192,7 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, // Taking into account this prune view, we'll attempt to locate a path // to our destination, respecting the recommendations from - // missionControl. + // MissionControl. path, err := p.pathFinder( &graphParams{ graph: p.mc.graph, diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index 0f3133bc..f730bc57 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -33,7 +33,7 @@ func TestRequestRoute(t *testing.T) { } session := &paymentSession{ - mc: &missionControl{ + mc: &MissionControl{ selfNode: &channeldb.LightningNode{}, }, pruneViewSnapshot: graphPruneView{}, diff --git a/routing/router.go b/routing/router.go index 75fe831d..6e85a096 100644 --- a/routing/router.go +++ b/routing/router.go @@ -147,6 +147,28 @@ type PaymentAttemptDispatcher interface { <-chan *htlcswitch.PaymentResult, error) } +// PaymentSessionSource is an interface that defines a source for the router to +// retrive new payment sessions. +type PaymentSessionSource interface { + // NewPaymentSession creates a new payment session that will produce + // routes to the given target. An optional set of routing hints can be + // provided in order to populate additional edges to explore when + // finding a path to the payment's destination. + NewPaymentSession(routeHints [][]zpay32.HopHint, + target route.Vertex) (PaymentSession, error) + + // NewPaymentSessionForRoute creates a new paymentSession instance that + // is just used for failure reporting to missioncontrol, and will only + // attempt the given route. + NewPaymentSessionForRoute(preBuiltRoute *route.Route) PaymentSession + + // NewPaymentSessionEmpty creates a new paymentSession instance that is + // empty, and will be exhausted immediately. Used for failure reporting + // to missioncontrol for resumed payment we don't want to make more + // attempts for. + NewPaymentSessionEmpty() PaymentSession +} + // FeeSchema is the set fee configuration for a Lightning Node on the network. // Using the coefficients described within the schema, the required fee to // forward outgoing payments can be derived. @@ -199,6 +221,19 @@ type Config struct { // their results. Payer PaymentAttemptDispatcher + // Control keeps track of the status of ongoing payments, ensuring we + // can properly resume them across restarts. + Control channeldb.ControlTower + + // MissionControl is a shared memory of sorts that executions of + // payment path finding use in order to remember which vertexes/edges + // were pruned from prior attempts. During SendPayment execution, + // errors sent by nodes are mapped into a vertex or edge to be pruned. + // Each run will then take into account this set of pruned + // vertexes/edges to reduce route failure and pass on graph information + // gained to the next execution. + MissionControl PaymentSessionSource + // ChannelPruneExpiry is the duration used to determine if a channel // should be pruned or not. If the delta between now and when the // channel was last updated is greater than ChannelPruneExpiry, then @@ -338,15 +373,6 @@ type ChannelRouter struct { // existing client. ntfnClientUpdates chan *topologyClientUpdate - // missionControl is a shared memory of sorts that executions of - // payment path finding use in order to remember which vertexes/edges - // were pruned from prior attempts. During SendPayment execution, - // errors sent by nodes are mapped into a vertex or edge to be pruned. - // Each run will then take into account this set of pruned - // vertexes/edges to reduce route failure and pass on graph information - // gained to the next execution. - missionControl *missionControl - // channelEdgeMtx is a mutex we use to make sure we process only one // ChannelEdgePolicy at a time for a given channelID, to ensure // consistency between the various database accesses. @@ -388,10 +414,6 @@ func New(cfg Config) (*ChannelRouter, error) { quit: make(chan struct{}), } - r.missionControl = newMissionControl( - cfg.Graph, selfNode, cfg.QueryBandwidth, - ) - return r, nil } @@ -488,6 +510,40 @@ func (r *ChannelRouter) Start() error { } } + // If any payments are still in flight, we resume, to make sure their + // results are properly handled. + payments, err := r.cfg.Control.FetchInFlightPayments() + if err != nil { + return err + } + + for _, payment := range payments { + log.Infof("Resuming payment with hash %v", payment.Info.PaymentHash) + r.wg.Add(1) + go func(payment *channeldb.InFlightPayment) { + defer r.wg.Done() + + // We create a dummy, empty payment session such that + // we won't make another payment attempt when the + // result for the in-flight attempt is received. + paySession := r.cfg.MissionControl.NewPaymentSessionEmpty() + + lPayment := &LightningPayment{ + PaymentHash: payment.Info.PaymentHash, + } + + _, _, err = r.sendPayment(payment.Attempt, lPayment, paySession) + if err != nil { + log.Errorf("Resuming payment with hash %v "+ + "failed: %v.", payment.Info.PaymentHash, err) + return + } + + log.Infof("Resumed payment with hash %v completed.", + payment.Info.PaymentHash) + }(payment) + } + r.wg.Add(1) go r.networkHandler() @@ -1513,6 +1569,7 @@ type LightningPayment struct { // when we should should abandon the payment attempt after consecutive // payment failure. This prevents us from attempting to send a payment // indefinitely. + // TODO(halseth): make wallclock time to allow resume after startup. PayAttemptTimeout time.Duration // RouteHints represents the different routing hints that can be used to @@ -1543,14 +1600,30 @@ func (r *ChannelRouter) SendPayment(payment *LightningPayment) ([32]byte, *route // Before starting the HTLC routing attempt, we'll create a fresh // payment session which will report our errors back to mission // control. - paySession, err := r.missionControl.NewPaymentSession( + paySession, err := r.cfg.MissionControl.NewPaymentSession( payment.RouteHints, payment.Target, ) if err != nil { return [32]byte{}, nil, err } - return r.sendPayment(payment, paySession) + // Record this payment hash with the ControlTower, ensuring it is not + // already in-flight. + info := &channeldb.PaymentCreationInfo{ + PaymentHash: payment.PaymentHash, + Value: payment.Amount, + CreationDate: time.Now(), + PaymentRequest: nil, + } + + err = r.cfg.Control.InitPayment(payment.PaymentHash, info) + if err != nil { + return [32]byte{}, nil, err + } + + // Since this is the first time this payment is being made, we pass nil + // for the existing attempt. + return r.sendPayment(nil, payment, paySession) } // SendToRoute attempts to send a payment with the given hash through the @@ -1560,7 +1633,7 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, route *route.Route) ( lntypes.Preimage, error) { // Create a payment session for just this route. - paySession := r.missionControl.NewPaymentSessionForRoute(route) + paySession := r.cfg.MissionControl.NewPaymentSessionForRoute(route) // Create a (mostly) dummy payment, as the created payment session is // not going to do path finding. @@ -1568,8 +1641,23 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, route *route.Route) ( PaymentHash: hash, } - preimage, _, err := r.sendPayment(payment, paySession) + // Record this payment hash with the ControlTower, ensuring it is not + // already in-flight. + info := &channeldb.PaymentCreationInfo{ + PaymentHash: payment.PaymentHash, + Value: payment.Amount, + CreationDate: time.Now(), + PaymentRequest: nil, + } + err := r.cfg.Control.InitPayment(payment.PaymentHash, info) + if err != nil { + return [32]byte{}, err + } + + // Since this is the first time this payment is being made, we pass nil + // for the existing attempt. + preimage, _, err := r.sendPayment(nil, payment, paySession) return preimage, err } @@ -1580,8 +1668,19 @@ func (r *ChannelRouter) SendToRoute(hash lntypes.Hash, route *route.Route) ( // will be returned which describes the path the successful payment traversed // within the network to reach the destination. Additionally, the payment // preimage will also be returned. -func (r *ChannelRouter) sendPayment(payment *LightningPayment, - paySession *paymentSession) ([32]byte, *route.Route, error) { +// +// The existing attempt argument should be set to nil if this is a payment that +// haven't had any payment attempt sent to the switch yet. If it has had an +// attempt already, it should be passed such that the result can be retrieved. +// +// This method relies on the ControlTower's internal payment state machine to +// carry out its execution. After restarts it is safe, and assumed, that the +// router will call this method for every payment still in-flight according to +// the ControlTower. +func (r *ChannelRouter) sendPayment( + existingAttempt *channeldb.PaymentAttemptInfo, + payment *LightningPayment, paySession PaymentSession) ( + [32]byte, *route.Route, error) { log.Tracef("Dispatching route for lightning payment: %v", newLogClosure(func() string { @@ -1617,171 +1716,22 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, timeoutChan := time.After(payAttemptTimeout) - // We'll continue until either our payment succeeds, or we encounter a - // critical error during path finding. - var lastError error - for { - // Before we attempt this next payment, we'll check to see if - // either we've gone past the payment attempt timeout, or the - // router is exiting. In either case, we'll stop this payment - // attempt short. - select { - case <-timeoutChan: - errStr := fmt.Sprintf("payment attempt not completed "+ - "before timeout of %v", payAttemptTimeout) - - return [32]byte{}, nil, newErr( - ErrPaymentAttemptTimeout, errStr, - ) - - case <-r.quit: - return [32]byte{}, nil, ErrRouterShuttingDown - - default: - // Fall through if we haven't hit our time limit, or - // are expiring. - } - - route, err := paySession.RequestRoute( - payment, uint32(currentHeight), finalCLTVDelta, - ) - if err != nil { - // If we're unable to successfully make a payment using - // any of the routes we've found, then return an error. - if lastError != nil { - return [32]byte{}, nil, fmt.Errorf("unable to "+ - "route payment to destination: %v", - lastError) - } - - return [32]byte{}, nil, err - } - - // Send payment attempt. It will return a final boolean - // indicating if more attempts are needed. - preimage, final, err := r.sendPaymentAttempt( - paySession, route, payment.PaymentHash, - ) - if final { - return preimage, route, err - } - - lastError = err - } -} - -// sendPaymentAttempt tries to send the payment via the specified route. If -// successful, it returns the obtained preimage. If an error occurs, the last -// bool parameter indicates whether this is a final outcome or more attempts -// should be made. -func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, - route *route.Route, paymentHash [32]byte) ([32]byte, bool, error) { - - log.Tracef("Attempting to send payment %x, using route: %v", - paymentHash, newLogClosure(func() string { - return spew.Sdump(route) - }), - ) - - // Generate a new key to be used for this attempt. - sessionKey, err := generateNewSessionKey() - if err != nil { - return [32]byte{}, true, err - } - // Generate the raw encoded sphinx packet to be included along - // with the htlcAdd message that we send directly to the - // switch. - onionBlob, circuit, err := generateSphinxPacket( - route, paymentHash[:], sessionKey, - ) - if err != nil { - return [32]byte{}, true, err + // Now set up a paymentLifecycle struct with these params, such that we + // can resume the payment from the current state. + p := &paymentLifecycle{ + router: r, + payment: payment, + paySession: paySession, + timeoutChan: timeoutChan, + currentHeight: currentHeight, + finalCLTVDelta: finalCLTVDelta, + attempt: existingAttempt, + circuit: nil, + lastError: nil, } - // Craft an HTLC packet to send to the layer 2 switch. The - // metadata within this packet will be used to route the - // payment through the network, starting with the first-hop. - htlcAdd := &lnwire.UpdateAddHTLC{ - Amount: route.TotalAmount, - Expiry: route.TotalTimeLock, - PaymentHash: paymentHash, - } - copy(htlcAdd.OnionBlob[:], onionBlob) + return p.resumePayment() - // Attempt to send this payment through the network to complete - // the payment. If this attempt fails, then we'll continue on - // to the next available route. - firstHop := lnwire.NewShortChanIDFromInt( - route.Hops[0].ChannelID, - ) - - // We generate a new, unique payment ID that we will use for - // this HTLC. - paymentID, err := r.cfg.NextPaymentID() - if err != nil { - return [32]byte{}, true, err - } - - err = r.cfg.Payer.SendHTLC( - firstHop, paymentID, htlcAdd, - ) - if err != nil { - log.Errorf("Failed sending attempt %d for payment %x to "+ - "switch: %v", paymentID, paymentHash, err) - - // We must inspect the error to know whether it was critical or - // not, to decide whether we should continue trying. - finalOutcome := r.processSendError( - paySession, route, err, - ) - - return [32]byte{}, finalOutcome, err - } - - // Using the created circuit, initialize the error decrypter so we can - // parse+decode any failures incurred by this payment within the - // switch. - errorDecryptor := &htlcswitch.SphinxErrorDecrypter{ - OnionErrorDecrypter: sphinx.NewOnionErrorDecrypter(circuit), - } - - // Now ask the switch to return the result of the payment when - // available. - resultChan, err := r.cfg.Payer.GetPaymentResult( - paymentID, errorDecryptor, - ) - if err != nil { - log.Errorf("Failed getting result for paymentID %d "+ - "from switch: %v", paymentID, err) - return [32]byte{}, true, err - } - - var ( - result *htlcswitch.PaymentResult - ok bool - ) - select { - case result, ok = <-resultChan: - if !ok { - return [32]byte{}, true, htlcswitch.ErrSwitchExiting - } - - case <-r.quit: - return [32]byte{}, true, ErrRouterShuttingDown - } - - if result.Error != nil { - log.Errorf("Attempt to send payment %x failed: %v", - paymentHash, result.Error) - - finalOutcome := r.processSendError( - paySession, route, result.Error, - ) - - return [32]byte{}, finalOutcome, result.Error - } - - return result.Preimage, true, nil } // processSendError analyzes the error for the payment attempt received from the @@ -1789,7 +1739,7 @@ func (r *ChannelRouter) sendPaymentAttempt(paySession *paymentSession, // error type, this error is either the final outcome of the payment or we need // to continue with an alternative route. This is indicated by the boolean // return value. -func (r *ChannelRouter) processSendError(paySession *paymentSession, +func (r *ChannelRouter) processSendError(paySession PaymentSession, rt *route.Route, err error) bool { fErr, ok := err.(*htlcswitch.ForwardingError) diff --git a/routing/router_test.go b/routing/router_test.go index 9dc0eed1..9dd2e85e 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -50,6 +50,7 @@ func (c *testCtx) RestartRouter() error { Chain: c.chain, ChainView: c.chainView, Payer: &mockPaymentAttemptDispatcher{}, + Control: makeMockControlTower(), ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, }) @@ -83,11 +84,25 @@ func createTestCtxFromGraphInstance(startingHeight uint32, graphInstance *testGr // be populated. chain := newMockChain(startingHeight) chainView := newMockChainView(chain) + + selfNode, err := graphInstance.graph.SourceNode() + if err != nil { + return nil, nil, err + } + + mc := NewMissionControl( + graphInstance.graph, selfNode, + func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + return lnwire.NewMSatFromSatoshis(e.Capacity) + }, + ) router, err := New(Config{ Graph: graphInstance.graph, Chain: chain, ChainView: chainView, Payer: &mockPaymentAttemptDispatcher{}, + Control: makeMockControlTower(), + MissionControl: mc, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { @@ -716,7 +731,9 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { }) // Once again, Roasbeef should route around Goku since they disagree - // w.r.t to the block height, and instead go through Pham Nuwen. + // w.r.t to the block height, and instead go through Pham Nuwen. We + // flip a bit in the payment hash to allow resending this payment. + payment.PaymentHash[1] ^= 1 paymentPreImage, rt, err = ctx.router.SendPayment(&payment) if err != nil { t.Fatalf("unable to send payment: %v", err) @@ -805,7 +822,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { return preImage, nil }) - ctx.router.missionControl.ResetHistory() + ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() // When we try to dispatch that payment, we should receive an error as // both attempts should fail and cause both routes to be pruned. @@ -820,7 +837,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { t.Fatalf("expected UnknownNextPeer instead got: %v", err) } - ctx.router.missionControl.ResetHistory() + ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() // Next, we'll modify the SendToSwitch method to indicate that luo ji // wasn't originally online. This should also halt the send all @@ -863,7 +880,7 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { ctx.aliases)) } - ctx.router.missionControl.ResetHistory() + ctx.router.cfg.MissionControl.(*MissionControl).ResetHistory() // Finally, we'll modify the SendToSwitch function to indicate that the // roasbeef -> luoji channel has insufficient capacity. This should @@ -883,6 +900,8 @@ func TestSendPaymentErrorPathPruning(t *testing.T) { return preImage, nil }) + // We flip a bit in the payment hash to allow resending this payment. + payment.PaymentHash[1] ^= 1 paymentPreImage, rt, err = ctx.router.SendPayment(&payment) if err != nil { t.Fatalf("unable to send payment: %v", err) @@ -1528,6 +1547,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { Chain: ctx.chain, ChainView: ctx.chainView, Payer: &mockPaymentAttemptDispatcher{}, + Control: makeMockControlTower(), ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, }) @@ -2490,3 +2510,629 @@ func assertChannelsPruned(t *testing.T, graph *channeldb.ChannelGraph, } } } + +// TestRouterPaymentStateMachine tests that the router interacts as expected +// with the ControlTower during a payment lifecycle, such that it payment +// attempts are not sent twice to the switch, and results are handled after a +// restart. +func TestRouterPaymentStateMachine(t *testing.T) { + t.Parallel() + + const startingBlockHeight = 101 + + // Setup two simple channels such that we can mock sending along this + // route. + chanCapSat := btcutil.Amount(100000) + testChannels := []*testChannel{ + symmetricTestChannel("a", "b", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 1), + symmetricTestChannel("b", "c", chanCapSat, &testChannelPolicy{ + Expiry: 144, + FeeRate: 400, + MinHTLC: 1, + MaxHTLC: lnwire.NewMSatFromSatoshis(chanCapSat), + }, 2), + } + + testGraph, err := createTestGraphFromChannels(testChannels) + if err != nil { + t.Fatalf("unable to create graph: %v", err) + } + defer testGraph.cleanUp() + + hop1 := testGraph.aliasMap["b"] + hop2 := testGraph.aliasMap["c"] + hops := []*route.Hop{ + { + ChannelID: 1, + PubKeyBytes: hop1, + }, + { + ChannelID: 2, + PubKeyBytes: hop2, + }, + } + + // We create a simple route that we will supply every time the router + // requests one. + rt, err := route.NewRouteFromHops( + lnwire.MilliSatoshi(10000), 100, testGraph.aliasMap["a"], hops, + ) + if err != nil { + t.Fatalf("unable to create route: %v", err) + } + + // A payment state machine test case consists of several ordered steps, + // that we use for driving the scenario. + type testCase struct { + // steps is a list of steps to perform during the testcase. + steps []string + + // routes is the sequence of routes we will provide to the + // router when it requests a new route. + routes []*route.Route + } + + const ( + // routerInitPayment is a test step where we expect the router + // to call the InitPayment method on the control tower. + routerInitPayment = "Router:init-payment" + + // routerRegisterAttempt is a test step where we expect the + // router to call the RegisterAttempt method on the control + // tower. + routerRegisterAttempt = "Router:register-attempt" + + // routerSuccess is a test step where we expect the router to + // call the Success method on the control tower. + routerSuccess = "Router:success" + + // routerFail is a test step where we expect the router to call + // the Fail method on the control tower. + routerFail = "Router:fail" + + // sendToSwitchSuccess is a step where we expect the router to + // call send the payment attempt to the switch, and we will + // respond with a non-error, indicating that the payment + // attempt was successfully forwarded. + sendToSwitchSuccess = "SendToSwitch:success" + + // sendToSwitchResultFailure is a step where we expect the + // router to send the payment attempt to the switch, and we + // will respond with a forwarding error. This can happen when + // forwarding fail on our local links. + sendToSwitchResultFailure = "SendToSwitch:failure" + + // getPaymentResultSuccess is a test step where we expect the + // router to call the GetPaymentResult method, and we will + // respond with a successful payment result. + getPaymentResultSuccess = "GetPaymentResult:success" + + // getPaymentResultFailure is a test step where we expect the + // router to call the GetPaymentResult method, and we will + // respond with a forwarding error. + getPaymentResultFailure = "GetPaymentResult:failure" + + // resendPayment is a test step where we manually try to resend + // the same payment, making sure the router responds with an + // error indicating that it is alreayd in flight. + resendPayment = "ResendPayment" + + // startRouter is a step where we manually start the router, + // used to test that it automatically will resume payments at + // startup. + startRouter = "StartRouter" + + // stopRouter is a test step where we manually make the router + // shut down. + stopRouter = "StopRouter" + + // paymentSuccess is a step where assert that we receive a + // successful result for the original payment made. + paymentSuccess = "PaymentSuccess" + + // paymentError is a step where assert that we receive an error + // for the original payment made. + paymentError = "PaymentError" + + // resentPaymentSuccess is a step where assert that we receive + // a successful result for a payment that was resent. + resentPaymentSuccess = "ResentPaymentSuccess" + + // resentPaymentError is a step where assert that we receive an + // error for a payment that was resent. + resentPaymentError = "ResentPaymentError" + ) + + tests := []testCase{ + { + // Tests a normal payment flow that succeeds. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + sendToSwitchSuccess, + getPaymentResultSuccess, + routerSuccess, + paymentSuccess, + }, + routes: []*route.Route{rt}, + }, + { + // A payment flow with a failure on the first attempt, + // but that succeeds on the second attempt. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + sendToSwitchSuccess, + + // Make the first sent attempt fail. + getPaymentResultFailure, + + // The router should retry. + routerRegisterAttempt, + sendToSwitchSuccess, + + // Make the second sent attempt succeed. + getPaymentResultSuccess, + routerSuccess, + paymentSuccess, + }, + routes: []*route.Route{rt, rt}, + }, + { + // A payment flow with a forwarding failure first time + // sending to the switch, but that succeeds on the + // second attempt. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + + // Make the first sent attempt fail. + sendToSwitchResultFailure, + + // The router should retry. + routerRegisterAttempt, + sendToSwitchSuccess, + + // Make the second sent attempt succeed. + getPaymentResultSuccess, + routerSuccess, + paymentSuccess, + }, + routes: []*route.Route{rt, rt}, + }, + { + // A payment that fails on the first attempt, and has + // only one route available to try. It will therefore + // fail permanently. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + sendToSwitchSuccess, + + // Make the first sent attempt fail. + getPaymentResultFailure, + + // Since there are no more routes to try, the + // payment should fail. + routerFail, + paymentError, + }, + routes: []*route.Route{rt}, + }, + { + // We expect the payment to fail immediately if we have + // no routes to try. + steps: []string{ + routerInitPayment, + routerFail, + paymentError, + }, + routes: []*route.Route{}, + }, + { + // A normal payment flow, where we attempt to resend + // the same payment after each step. This ensures that + // the router don't attempt to resend a payment already + // in flight. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + + // Manually resend the payment, the router + // should attempt to init with the control + // tower, but fail since it is already in + // flight. + resendPayment, + routerInitPayment, + resentPaymentError, + + // The original payment should proceed as + // normal. + sendToSwitchSuccess, + + // Again resend the payment and assert it's not + // allowed. + resendPayment, + routerInitPayment, + resentPaymentError, + + // Notify about a success for the original + // payment. + getPaymentResultSuccess, + routerSuccess, + + // Now that the original payment finished, + // resend it again to ensure this is not + // allowed. + resendPayment, + routerInitPayment, + resentPaymentError, + paymentSuccess, + }, + routes: []*route.Route{rt}, + }, + { + // Tests that the router is able to handle the + // receieved payment result after a restart. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + sendToSwitchSuccess, + + // Shut down the router. The original caller + // should get notified about this. + stopRouter, + paymentError, + + // Start the router again, and ensure the + // router registers the success with the + // control tower. + startRouter, + getPaymentResultSuccess, + routerSuccess, + }, + routes: []*route.Route{rt}, + }, + { + // Tests that we are allowed to resend a payment after + // it has permanently failed. + steps: []string{ + routerInitPayment, + routerRegisterAttempt, + sendToSwitchSuccess, + + // Resending the payment at this stage should + // not be allowed. + resendPayment, + routerInitPayment, + resentPaymentError, + + // Make the first attempt fail. + getPaymentResultFailure, + routerFail, + + // Since we have no more routes to try, the + // original payment should fail. + paymentError, + + // Now resend the payment again. This should be + // allowed, since the payment has failed. + resendPayment, + routerInitPayment, + routerRegisterAttempt, + sendToSwitchSuccess, + getPaymentResultSuccess, + routerSuccess, + resentPaymentSuccess, + }, + routes: []*route.Route{rt}, + }, + } + + // Create a mock control tower with channels set up, that we use to + // synchronize and listen for events. + control := makeMockControlTower() + control.init = make(chan initArgs) + control.register = make(chan registerArgs) + control.success = make(chan successArgs) + control.fail = make(chan failArgs) + control.fetchInFlight = make(chan struct{}) + + quit := make(chan struct{}) + defer close(quit) + + // setupRouter is a helper method that creates and starts the router in + // the desired configuration for this test. + setupRouter := func() (*ChannelRouter, chan error, + chan *htlcswitch.PaymentResult, chan error) { + + chain := newMockChain(startingBlockHeight) + chainView := newMockChainView(chain) + + // We set uo the use the following channels and a mock Payer to + // synchonize with the interaction to the Switch. + sendResult := make(chan error) + paymentResultErr := make(chan error) + paymentResult := make(chan *htlcswitch.PaymentResult) + + payer := &mockPayer{ + sendResult: sendResult, + paymentResult: paymentResult, + paymentResultErr: paymentResultErr, + } + + router, err := New(Config{ + Graph: testGraph.graph, + Chain: chain, + ChainView: chainView, + Control: control, + MissionControl: &mockPaymentSessionSource{}, + Payer: payer, + ChannelPruneExpiry: time.Hour * 24, + GraphPruneInterval: time.Hour * 2, + QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + return lnwire.NewMSatFromSatoshis(e.Capacity) + }, + NextPaymentID: func() (uint64, error) { + next := atomic.AddUint64(&uniquePaymentID, 1) + return next, nil + }, + }) + if err != nil { + t.Fatalf("unable to create router %v", err) + } + + // On startup, the router should fetch all pending payments + // from the ControlTower, so assert that here. + didFetch := make(chan struct{}) + go func() { + select { + case <-control.fetchInFlight: + close(didFetch) + case <-time.After(1 * time.Second): + t.Fatalf("router did not fetch in flight " + + "payments") + } + }() + + if err := router.Start(); err != nil { + t.Fatalf("unable to start router: %v", err) + } + + select { + case <-didFetch: + case <-time.After(1 * time.Second): + t.Fatalf("did not fetch in flight payments at startup") + } + + return router, sendResult, paymentResult, paymentResultErr + } + + router, sendResult, getPaymentResult, getPaymentResultErr := setupRouter() + defer router.Stop() + + for _, test := range tests { + // Craft a LightningPayment struct. + var preImage lntypes.Preimage + if _, err := rand.Read(preImage[:]); err != nil { + t.Fatalf("unable to generate preimage") + } + + payHash := preImage.Hash() + + paymentAmt := lnwire.NewMSatFromSatoshis(1000) + payment := LightningPayment{ + Target: testGraph.aliasMap["c"], + Amount: paymentAmt, + FeeLimit: noFeeLimit, + PaymentHash: payHash, + } + + errSource, err := btcec.ParsePubKey(hop1[:], btcec.S256()) + if err != nil { + t.Fatalf("unable to fetch source node pub: %v", err) + } + + copy(preImage[:], bytes.Repeat([]byte{9}, 32)) + + router.cfg.MissionControl = &mockPaymentSessionSource{ + routes: test.routes, + } + + // Send the payment. Since this is new payment hash, the + // information should be registered with the ControlTower. + paymentResult := make(chan error) + go func() { + _, _, err := router.SendPayment(&payment) + paymentResult <- err + }() + + var resendResult chan error + for _, step := range test.steps { + switch step { + + case routerInitPayment: + var args initArgs + select { + case args = <-control.init: + case <-time.After(1 * time.Second): + t.Fatalf("no init payment with control") + } + + if args.c == nil { + t.Fatalf("expected non-nil CreationInfo") + } + + // In this step we expect the router to make a call to + // register a new attempt with the ControlTower. + case routerRegisterAttempt: + var args registerArgs + select { + case args = <-control.register: + case <-time.After(1 * time.Second): + t.Fatalf("not registered with control") + } + + if args.a == nil { + t.Fatalf("expected non-nil AttemptInfo") + } + + // In this step we expect the router to call the + // ControlTower's Succcess method with the preimage. + case routerSuccess: + select { + case _ = <-control.success: + case <-time.After(1 * time.Second): + t.Fatalf("not registered with control") + } + + // In this step we expect the router to call the + // ControlTower's Fail method, to indicate that the + // payment failed. + case routerFail: + select { + case _ = <-control.fail: + case <-time.After(1 * time.Second): + t.Fatalf("not registered with control") + } + + // In this step we expect the SendToSwitch method to be + // called, and we respond with a nil-error. + case sendToSwitchSuccess: + select { + case sendResult <- nil: + case <-time.After(1 * time.Second): + t.Fatalf("unable to send result") + } + + // In this step we expect the SendToSwitch method to be + // called, and we respond with a forwarding error + case sendToSwitchResultFailure: + select { + case sendResult <- &htlcswitch.ForwardingError{ + ErrorSource: errSource, + FailureMessage: &lnwire.FailTemporaryChannelFailure{}, + }: + case <-time.After(1 * time.Second): + t.Fatalf("unable to send result") + } + + // In this step we expect the GetPaymentResult method + // to be called, and we respond with the preimage to + // complete the payment. + case getPaymentResultSuccess: + select { + case getPaymentResult <- &htlcswitch.PaymentResult{ + Preimage: preImage, + }: + case <-time.After(1 * time.Second): + t.Fatalf("unable to send result") + } + + // In this state we expect the GetPaymentResult method + // to be called, and we respond with a forwarding + // error, indicating that the router should retry. + case getPaymentResultFailure: + select { + case getPaymentResult <- &htlcswitch.PaymentResult{ + Error: &htlcswitch.ForwardingError{ + ErrorSource: errSource, + FailureMessage: &lnwire.FailTemporaryChannelFailure{}, + }, + }: + case <-time.After(1 * time.Second): + t.Fatalf("unable to get result") + } + + // In this step we manually try to resend the same + // payment, making sure the router responds with an + // error indicating that it is alreayd in flight. + case resendPayment: + resendResult = make(chan error) + go func() { + _, _, err := router.SendPayment(&payment) + resendResult <- err + }() + + // In this step we manually stop the router. + case stopRouter: + select { + case getPaymentResultErr <- fmt.Errorf( + "shutting down"): + case <-time.After(1 * time.Second): + t.Fatalf("unable to send payment " + + "result error") + } + + if err := router.Stop(); err != nil { + t.Fatalf("unable to restart: %v", err) + } + + // In this step we manually start the router. + case startRouter: + router, sendResult, getPaymentResult, + getPaymentResultErr = setupRouter() + + // In this state we expect to receive an error for the + // original payment made. + case paymentError: + select { + case err := <-paymentResult: + if err == nil { + t.Fatalf("expected error") + } + + case <-time.After(1 * time.Second): + t.Fatalf("got no payment result") + } + + // In this state we expect the original payment to + // succeed. + case paymentSuccess: + select { + case err := <-paymentResult: + if err != nil { + t.Fatalf("did not expecte error %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("got no payment result") + } + + // In this state we expect to receive an error for the + // resent payment made. + case resentPaymentError: + select { + case err := <-resendResult: + if err == nil { + t.Fatalf("expected error") + } + + case <-time.After(1 * time.Second): + t.Fatalf("got no payment result") + } + + // In this state we expect the resent payment to + // succeed. + case resentPaymentSuccess: + select { + case err := <-resendResult: + if err != nil { + t.Fatalf("did not expect error %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("got no payment result") + } + + default: + t.Fatalf("unknown step %v", step) + } + } + } +} diff --git a/rpcserver.go b/rpcserver.go index 4fd091b1..e5910d27 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -2,7 +2,6 @@ package lnd import ( "bytes" - "crypto/sha256" "crypto/tls" "encoding/hex" "errors" @@ -2738,33 +2737,6 @@ func (r *rpcServer) SubscribeChannelEvents(req *lnrpc.ChannelEventSubscription, } } -// savePayment saves a successfully completed payment to the database for -// historical record keeping. -func (r *rpcServer) savePayment(route *route.Route, - amount lnwire.MilliSatoshi, preImage []byte) error { - - paymentPath := make([][33]byte, len(route.Hops)) - for i, hop := range route.Hops { - hopPub := hop.PubKeyBytes - copy(paymentPath[i][:], hopPub[:]) - } - - payment := &channeldb.OutgoingPayment{ - Invoice: channeldb.Invoice{ - Terms: channeldb.ContractTerm{ - Value: amount, - }, - CreationDate: time.Now(), - }, - Path: paymentPath, - Fee: route.TotalFees(), - TimeLockLength: route.TotalTimeLock, - } - copy(payment.PaymentPreimage[:], preImage) - - return r.server.chanDB.AddPayment(payment) -} - // validatePayReqExpiry checks if the passed payment request has expired. In // the case it has expired, an error will be returned. func validatePayReqExpiry(payReq *zpay32.Invoice) error { @@ -3126,18 +3098,6 @@ func (r *rpcServer) dispatchPaymentIntent( }, nil } - // Calculate amount paid to receiver. - amt := route.TotalAmount - route.TotalFees() - - // Save the completed payment to the database for record keeping - // purposes. - err := r.savePayment(route, amt, preImage[:]) - if err != nil { - // We weren't able to save the payment, so we return the save - // err, but a nil routing err. - return nil, err - } - return &paymentIntentResponse{ Route: route, Preimage: preImage, @@ -4149,8 +4109,8 @@ func (r *rpcServer) ListPayments(ctx context.Context, rpcsLog.Debugf("[ListPayments]") - payments, err := r.server.chanDB.FetchAllPayments() - if err != nil && err != channeldb.ErrNoPaymentsCreated { + payments, err := r.server.chanDB.FetchPayments() + if err != nil { return nil, err } @@ -4158,24 +4118,37 @@ func (r *rpcServer) ListPayments(ctx context.Context, Payments: make([]*lnrpc.Payment, len(payments)), } for i, payment := range payments { - path := make([]string, len(payment.Path)) - for i, hop := range payment.Path { - path[i] = hex.EncodeToString(hop[:]) + // If a payment attempt has been made we can fetch the route. + // Otherwise we'll just populate the RPC response with an empty + // one. + var route route.Route + if payment.Attempt != nil { + route = payment.Attempt.Route + } + path := make([]string, len(route.Hops)) + for i, hop := range route.Hops { + path[i] = hex.EncodeToString(hop.PubKeyBytes[:]) } - msatValue := int64(payment.Terms.Value) - satValue := int64(payment.Terms.Value.ToSatoshis()) + // If this payment is settled, the preimage will be available. + var preimage lntypes.Preimage + if payment.PaymentPreimage != nil { + preimage = *payment.PaymentPreimage + } - paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + msatValue := int64(payment.Info.Value) + satValue := int64(payment.Info.Value.ToSatoshis()) + + paymentHash := payment.Info.PaymentHash paymentsResp.Payments[i] = &lnrpc.Payment{ PaymentHash: hex.EncodeToString(paymentHash[:]), Value: satValue, ValueMsat: msatValue, ValueSat: satValue, - CreationDate: payment.CreationDate.Unix(), + CreationDate: payment.Info.CreationDate.Unix(), Path: path, - Fee: int64(payment.Fee.ToSatoshis()), - PaymentPreimage: hex.EncodeToString(payment.PaymentPreimage[:]), + Fee: int64(route.TotalFees().ToSatoshis()), + PaymentPreimage: hex.EncodeToString(preimage[:]), } } @@ -4188,7 +4161,7 @@ func (r *rpcServer) DeleteAllPayments(ctx context.Context, rpcsLog.Debugf("[DeleteAllPayments]") - if err := r.server.chanDB.DeleteAllPayments(); err != nil { + if err := r.server.chanDB.DeletePayments(); err != nil { return nil, err } diff --git a/server.go b/server.go index 06a1b150..93f99737 100644 --- a/server.go +++ b/server.go @@ -616,42 +616,50 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl, return nil, err } + queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + // If we aren't on either side of this edge, then we'll + // just thread through the capacity of the edge as we + // know it. + if !bytes.Equal(edge.NodeKey1Bytes[:], selfNode.PubKeyBytes[:]) && + !bytes.Equal(edge.NodeKey2Bytes[:], selfNode.PubKeyBytes[:]) { + + return lnwire.NewMSatFromSatoshis(edge.Capacity) + } + + cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) + link, err := s.htlcSwitch.GetLink(cid) + if err != nil { + // If the link isn't online, then we'll report + // that it has zero bandwidth to the router. + return 0 + } + + // If the link is found within the switch, but it isn't + // yet eligible to forward any HTLCs, then we'll treat + // it as if it isn't online in the first place. + if !link.EligibleToForward() { + return 0 + } + + // Otherwise, we'll return the current best estimate + // for the available bandwidth for the link. + return link.Bandwidth() + } + + missionControl := routing.NewMissionControl( + chanGraph, selfNode, queryBandwidth, + ) + s.chanRouter, err = routing.New(routing.Config{ Graph: chanGraph, Chain: cc.chainIO, ChainView: cc.chainView, Payer: s.htlcSwitch, + Control: channeldb.NewPaymentControl(chanDB), + MissionControl: missionControl, ChannelPruneExpiry: routing.DefaultChannelPruneExpiry, GraphPruneInterval: time.Duration(time.Hour), - QueryBandwidth: func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - // If we aren't on either side of this edge, then we'll - // just thread through the capacity of the edge as we - // know it. - if !bytes.Equal(edge.NodeKey1Bytes[:], selfNode.PubKeyBytes[:]) && - !bytes.Equal(edge.NodeKey2Bytes[:], selfNode.PubKeyBytes[:]) { - - return lnwire.NewMSatFromSatoshis(edge.Capacity) - } - - cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) - link, err := s.htlcSwitch.GetLink(cid) - if err != nil { - // If the link isn't online, then we'll report - // that it has zero bandwidth to the router. - return 0 - } - - // If the link is found within the switch, but it isn't - // yet eligible to forward any HTLCs, then we'll treat - // it as if it isn't online in the first place. - if !link.EligibleToForward() { - return 0 - } - - // Otherwise, we'll return the current best estimate - // for the available bandwidth for the link. - return link.Bandwidth() - }, + QueryBandwidth: queryBandwidth, AssumeChannelValid: cfg.Routing.UseAssumeChannelValid(), NextPaymentID: sequencer.NextID, })