Merge pull request #1719 from cfromknecht/duplicate-payments

Control Tower: Sender-side checks for duplicate payments #2
This commit is contained in:
Olaoluwa Osuntokun 2018-08-22 21:22:38 -07:00 committed by GitHub
commit 4f20905ac1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1334 additions and 177 deletions

@ -67,6 +67,13 @@ var (
number: 4, number: 4,
migration: migrateEdgePolicies, migration: migrateEdgePolicies,
}, },
{
// The DB version where we persist each attempt to send
// an HTLC to a payment hash, and track whether the
// payment is in-flight, succeeded, or failed.
number: 5,
migration: paymentStatusesMigration,
},
} }
// Big endian is the preferred byte order, due to cursor scans over // Big endian is the preferred byte order, due to cursor scans over

@ -9,6 +9,59 @@ import (
"github.com/go-errors/errors" "github.com/go-errors/errors"
) )
// applyMigration is a helper test function that encapsulates the general steps
// which are needed to properly check the result of applying migration function.
func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB),
migrationFunc migration, shouldFail bool) {
cdb, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatal(err)
}
// beforeMigration usually used for populating the database
// with test data.
beforeMigration(cdb)
// Create test meta info with zero database version and put it on disk.
// Than creating the version list pretending that new version was added.
meta := &Meta{DbVersionNumber: 0}
if err := cdb.PutMeta(meta); err != nil {
t.Fatalf("unable to store meta data: %v", err)
}
versions := []version{
{
number: 0,
migration: nil,
},
{
number: 1,
migration: migrationFunc,
},
}
defer func() {
if r := recover(); r != nil {
err = errors.New(r)
}
if err == nil && shouldFail {
t.Fatal("error wasn't received on migration stage")
} else if err != nil && !shouldFail {
t.Fatal("error was received on migration stage")
}
// afterMigration usually used for checking the database state and
// throwing the error if something went wrong.
afterMigration(cdb)
}()
// Sync with the latest version - applying migration function.
err = cdb.syncVersions(versions)
}
// TestVersionFetchPut checks the propernces of fetch/put methods // TestVersionFetchPut checks the propernces of fetch/put methods
// and also initialization of meta data in case if don't have any in // and also initialization of meta data in case if don't have any in
// database. // database.
@ -118,59 +171,8 @@ func TestGlobalVersionList(t *testing.T) {
} }
} }
// applyMigration is a helper test function that encapsulates the general steps // TestMigrationWithPanic asserts that if migration logic panics, we will return
// which are needed to properly check the result of applying migration function. // to the original state unaltered.
func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB),
migrationFunc migration, shouldFail bool) {
cdb, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatal(err)
}
// beforeMigration usually used for populating the database
// with test data.
beforeMigration(cdb)
// Create test meta info with zero database version and put it on disk.
// Than creating the version list pretending that new version was added.
meta := &Meta{DbVersionNumber: 0}
if err := cdb.PutMeta(meta); err != nil {
t.Fatalf("unable to store meta data: %v", err)
}
versions := []version{
{
number: 0,
migration: nil,
},
{
number: 1,
migration: migrationFunc,
},
}
defer func() {
if r := recover(); r != nil {
err = errors.New(r)
}
if err == nil && shouldFail {
t.Fatal("error wasn't received on migration stage")
} else if err != nil && !shouldFail {
t.Fatal("error was received on migration stage")
}
// afterMigration usually used for checking the database state and
// throwing the error if something went wrong.
afterMigration(cdb)
}()
// Sync with the latest version - applying migration function.
err = cdb.syncVersions(versions)
}
func TestMigrationWithPanic(t *testing.T) { func TestMigrationWithPanic(t *testing.T) {
t.Parallel() t.Parallel()
@ -242,6 +244,8 @@ func TestMigrationWithPanic(t *testing.T) {
true) true)
} }
// TestMigrationWithFatal asserts that migrations which fail do not modify the
// database.
func TestMigrationWithFatal(t *testing.T) { func TestMigrationWithFatal(t *testing.T) {
t.Parallel() t.Parallel()
@ -312,6 +316,8 @@ func TestMigrationWithFatal(t *testing.T) {
true) true)
} }
// TestMigrationWithoutErrors asserts that a successful migration has its
// changes applied to the database.
func TestMigrationWithoutErrors(t *testing.T) { func TestMigrationWithoutErrors(t *testing.T) {
t.Parallel() t.Parallel()

@ -2,6 +2,8 @@ package channeldb
import ( import (
"bytes" "bytes"
"crypto/sha256"
"encoding/binary"
"fmt" "fmt"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
@ -373,3 +375,86 @@ func migrateEdgePolicies(tx *bolt.Tx) error {
return nil return nil
} }
// paymentStatusesMigration is a database migration intended for adding payment
// statuses for each existing payment entity in bucket to be able control
// transitions of statuses and prevent cases such as double payment
func paymentStatusesMigration(tx *bolt.Tx) error {
// Get the bucket dedicated to storing statuses of payments,
// where a key is payment hash, value is payment status.
paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket)
if err != nil {
return err
}
log.Infof("Migrating database to support payment statuses")
circuitAddKey := []byte("circuit-adds")
circuits := tx.Bucket(circuitAddKey)
if circuits != nil {
log.Infof("Marking all known circuits with status InFlight")
err = circuits.ForEach(func(k, v []byte) error {
// Parse the first 8 bytes as the short chan ID for the
// circuit. We'll skip all short chan IDs are not
// locally initiated, which includes all non-zero short
// chan ids.
chanID := binary.BigEndian.Uint64(k[:8])
if chanID != 0 {
return nil
}
// The payment hash is the third item in the serialized
// payment circuit. The first two items are an AddRef
// (10 bytes) and the incoming circuit key (16 bytes).
const payHashOffset = 10 + 16
paymentHash := v[payHashOffset : payHashOffset+32]
return paymentStatuses.Put(
paymentHash[:], StatusInFlight.Bytes(),
)
})
if err != nil {
return err
}
}
log.Infof("Marking all existing payments with status Completed")
// Get the bucket dedicated to storing payments
bucket := tx.Bucket(paymentBucket)
if bucket == nil {
return nil
}
// For each payment in the bucket, deserialize the payment and mark it
// as completed.
err = bucket.ForEach(func(k, v []byte) error {
// Ignores if it is sub-bucket.
if v == nil {
return nil
}
r := bytes.NewReader(v)
payment, err := deserializeOutgoingPayment(r)
if err != nil {
return err
}
// Calculate payment hash for current payment.
paymentHash := sha256.Sum256(payment.PaymentPreimage[:])
// Update status for current payment to completed. If it fails,
// the migration is aborted and the payment bucket is returned
// to its previous state.
return paymentStatuses.Put(paymentHash[:], StatusCompleted.Bytes())
})
if err != nil {
return err
}
log.Infof("Migration of payment statuses complete!")
return nil
}

@ -0,0 +1,191 @@
package channeldb
import (
"crypto/sha256"
"encoding/binary"
"testing"
"github.com/coreos/bbolt"
)
// TestPaymentStatusesMigration checks that already completed payments will have
// their payment statuses set to Completed after the migration.
func TestPaymentStatusesMigration(t *testing.T) {
t.Parallel()
fakePayment := makeFakePayment()
paymentHash := sha256.Sum256(fakePayment.PaymentPreimage[:])
// Add fake payment to test database, verifying that it was created,
// that we have only one payment, and its status is not "Completed".
beforeMigrationFunc := func(d *DB) {
if err := d.AddPayment(fakePayment); err != nil {
t.Fatalf("unable to add payment: %v", err)
}
payments, err := d.FetchAllPayments()
if err != nil {
t.Fatalf("unable to fetch payments: %v", err)
}
if len(payments) != 1 {
t.Fatalf("wrong qty of paymets: expected 1, got %v",
len(payments))
}
paymentStatus, err := d.FetchPaymentStatus(paymentHash)
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
}
// We should receive default status if we have any in database.
if paymentStatus != StatusGrounded {
t.Fatalf("wrong payment status: expected %v, got %v",
StatusGrounded.String(), paymentStatus.String())
}
// Lastly, we'll add a locally-sourced circuit and
// non-locally-sourced circuit to the circuit map. The
// locally-sourced payment should end up with an InFlight
// status, while the other should remain unchanged, which
// defaults to Grounded.
err = d.Update(func(tx *bolt.Tx) error {
circuits, err := tx.CreateBucketIfNotExists(
[]byte("circuit-adds"),
)
if err != nil {
return err
}
groundedKey := make([]byte, 16)
binary.BigEndian.PutUint64(groundedKey[:8], 1)
binary.BigEndian.PutUint64(groundedKey[8:], 1)
// Generated using TestHalfCircuitSerialization with nil
// ErrorEncrypter, which is the case for locally-sourced
// payments. No payment status should end up being set
// for this circuit, since the short channel id of the
// key is non-zero (e.g., a forwarded circuit). This
// will default it to Grounded.
groundedCircuit := []byte{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01,
// start payment hash
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// end payment hash
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f,
0x42, 0x40, 0x00,
}
err = circuits.Put(groundedKey, groundedCircuit)
if err != nil {
return err
}
inFlightKey := make([]byte, 16)
binary.BigEndian.PutUint64(inFlightKey[:8], 0)
binary.BigEndian.PutUint64(inFlightKey[8:], 1)
// Generated using TestHalfCircuitSerialization with nil
// ErrorEncrypter, which is not the case for forwarded
// payments, but should have no impact on the
// correctness of the test. The payment status for this
// circuit should be set to InFlight, since the short
// channel id in the key is 0 (sourceHop).
inFlightCircuit := []byte{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x01,
// start payment hash
0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
// end payment hash
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x0f,
0x42, 0x40, 0x00,
}
return circuits.Put(inFlightKey, inFlightCircuit)
})
if err != nil {
t.Fatalf("unable to add circuit map entry: %v", err)
}
}
// Verify that the created payment status is "Completed" for our one
// fake payment.
afterMigrationFunc := func(d *DB) {
meta, err := d.FetchMeta(nil)
if err != nil {
t.Fatal(err)
}
if meta.DbVersionNumber != 1 {
t.Fatal("migration 'paymentStatusesMigration' wasn't applied")
}
// Check that our completed payments were migrated.
paymentStatus, err := d.FetchPaymentStatus(paymentHash)
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
}
if paymentStatus != StatusCompleted {
t.Fatalf("wrong payment status: expected %v, got %v",
StatusCompleted.String(), paymentStatus.String())
}
inFlightHash := [32]byte{
0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}
// Check that the locally sourced payment was transitioned to
// InFlight.
paymentStatus, err = d.FetchPaymentStatus(inFlightHash)
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
}
if paymentStatus != StatusInFlight {
t.Fatalf("wrong payment status: expected %v, got %v",
StatusInFlight.String(), paymentStatus.String())
}
groundedHash := [32]byte{
0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
}
// Check that non-locally sourced payments remain in the default
// Grounded state.
paymentStatus, err = d.FetchPaymentStatus(groundedHash)
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
}
if paymentStatus != StatusGrounded {
t.Fatalf("wrong payment status: expected %v, got %v",
StatusGrounded.String(), paymentStatus.String())
}
}
applyMigration(t,
beforeMigrationFunc,
afterMigrationFunc,
paymentStatusesMigration,
false)
}

@ -3,6 +3,7 @@ package channeldb
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
@ -17,8 +18,65 @@ var (
// which is a monotonically increasing uint64. BoltDB's sequence // which is a monotonically increasing uint64. BoltDB's sequence
// feature is used for generating monotonically increasing id. // feature is used for generating monotonically increasing id.
paymentBucket = []byte("payments") paymentBucket = []byte("payments")
// paymentStatusBucket is the name of the bucket within the database that
// stores the status of a payment indexed by the payment's preimage.
paymentStatusBucket = []byte("payment-status")
) )
// PaymentStatus represent current status of payment
type PaymentStatus byte
const (
// StatusGrounded is the status where a payment has never been
// initiated, or has been initiated and received an intermittent
// failure.
StatusGrounded PaymentStatus = 0
// StatusInFlight is the status where a payment has been initiated, but
// a response has not been received.
StatusInFlight PaymentStatus = 1
// StatusCompleted is the status where a payment has been initiated and
// the payment was completed successfully.
StatusCompleted PaymentStatus = 2
)
// Bytes returns status as slice of bytes.
func (ps PaymentStatus) Bytes() []byte {
return []byte{byte(ps)}
}
// FromBytes sets status from slice of bytes.
func (ps *PaymentStatus) FromBytes(status []byte) error {
if len(status) != 1 {
return errors.New("payment status is empty")
}
switch PaymentStatus(status[0]) {
case StatusGrounded, StatusInFlight, StatusCompleted:
*ps = PaymentStatus(status[0])
default:
return errors.New("unknown payment status")
}
return nil
}
// String returns readable representation of payment status.
func (ps PaymentStatus) String() string {
switch ps {
case StatusGrounded:
return "Grounded"
case StatusInFlight:
return "In Flight"
case StatusCompleted:
return "Completed"
default:
return "Unknown"
}
}
// OutgoingPayment represents a successful payment between the daemon and a // OutgoingPayment represents a successful payment between the daemon and a
// remote node. Details such as the total fee paid, and the time of the payment // remote node. Details such as the total fee paid, and the time of the payment
// are stored. // are stored.
@ -129,6 +187,68 @@ func (db *DB) DeleteAllPayments() error {
}) })
} }
// UpdatePaymentStatus sets the payment status for outgoing/finished payments in
// local database.
func (db *DB) UpdatePaymentStatus(paymentHash [32]byte, status PaymentStatus) error {
return db.Batch(func(tx *bolt.Tx) error {
return UpdatePaymentStatusTx(tx, paymentHash, status)
})
}
// UpdatePaymentStatusTx is a helper method that sets the payment status for
// outgoing/finished payments in the local database. This method accepts a
// boltdb transaction such that the operation can be composed into other
// database transactions.
func UpdatePaymentStatusTx(tx *bolt.Tx,
paymentHash [32]byte, status PaymentStatus) error {
paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket)
if err != nil {
return err
}
return paymentStatuses.Put(paymentHash[:], status.Bytes())
}
// FetchPaymentStatus returns the payment status for outgoing payment.
// If status of the payment isn't found, it will default to "StatusGrounded".
func (db *DB) FetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) {
var paymentStatus = StatusGrounded
err := db.View(func(tx *bolt.Tx) error {
var err error
paymentStatus, err = FetchPaymentStatusTx(tx, paymentHash)
return err
})
if err != nil {
return StatusGrounded, err
}
return paymentStatus, nil
}
// FetchPaymentStatusTx is a helper method that returns the payment status for
// outgoing payment. If status of the payment isn't found, it will default to
// "StatusGrounded". It accepts the boltdb transactions such that this method
// can be composed into other atomic operations.
func FetchPaymentStatusTx(tx *bolt.Tx, paymentHash [32]byte) (PaymentStatus, error) {
// The default status for all payments that aren't recorded in database.
var paymentStatus = StatusGrounded
bucket := tx.Bucket(paymentStatusBucket)
if bucket == nil {
return paymentStatus, nil
}
paymentStatusBytes := bucket.Get(paymentHash[:])
if paymentStatusBytes == nil {
return paymentStatus, nil
}
paymentStatus.FromBytes(paymentStatusBytes)
return paymentStatus, nil
}
func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error {
var scratch [8]byte var scratch [8]byte

@ -40,6 +40,14 @@ func makeFakePayment() *OutgoingPayment {
return fakePayment return fakePayment
} }
func makeFakePaymentHash() [32]byte {
var paymentHash [32]byte
rBytes, _ := randomBytes(0, 32)
copy(paymentHash[:], rBytes)
return paymentHash
}
// randomBytes creates random []byte with length in range [minLen, maxLen) // randomBytes creates random []byte with length in range [minLen, maxLen)
func randomBytes(minLen, maxLen int) ([]byte, error) { func randomBytes(minLen, maxLen int) ([]byte, error) {
randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen))
@ -195,3 +203,51 @@ func TestOutgoingPaymentWorkflow(t *testing.T) {
len(paymentsAfterDeletion), 0) len(paymentsAfterDeletion), 0)
} }
} }
func TestPaymentStatusWorkflow(t *testing.T) {
t.Parallel()
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
testCases := []struct {
paymentHash [32]byte
status PaymentStatus
}{
{
paymentHash: makeFakePaymentHash(),
status: StatusGrounded,
},
{
paymentHash: makeFakePaymentHash(),
status: StatusInFlight,
},
{
paymentHash: makeFakePaymentHash(),
status: StatusCompleted,
},
}
for _, testCase := range testCases {
err := db.UpdatePaymentStatus(testCase.paymentHash, testCase.status)
if err != nil {
t.Fatalf("unable to put payment in DB: %v", err)
}
status, err := db.FetchPaymentStatus(testCase.paymentHash)
if err != nil {
t.Fatalf("unable to fetch payments from DB: %v", err)
}
if status != testCase.status {
t.Fatalf("Wrong payments status after reading from DB."+
"Got %v, want %v",
spew.Sdump(status),
spew.Sdump(testCase.status),
)
}
}
}

245
htlcswitch/control_tower.go Normal file

@ -0,0 +1,245 @@
package htlcswitch
import (
"errors"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrAlreadyPaid signals we have already paid this payment hash.
ErrAlreadyPaid = errors.New("invoice is already paid")
// ErrPaymentInFlight signals that payment for this payment hash is
// already "in flight" on the network.
ErrPaymentInFlight = errors.New("payment is in transition")
// ErrPaymentNotInitiated is returned if payment wasn't initiated in
// switch.
ErrPaymentNotInitiated = errors.New("payment isn't initiated")
// ErrPaymentAlreadyCompleted is returned in the event we attempt to
// recomplete a completed payment.
ErrPaymentAlreadyCompleted = errors.New("payment is already completed")
// ErrUnknownPaymentStatus is returned when we do not recognize the
// existing state of a payment.
ErrUnknownPaymentStatus = errors.New("unknown payment status")
)
// ControlTower tracks all outgoing payments made by the switch, whose primary
// purpose is to prevent duplicate payments to the same payment hash. In
// production, a persistent implementation is preferred so that tracking can
// survive across restarts. Payments are transition through various payment
// states, and the ControlTower interface provides access to driving the state
// transitions.
type ControlTower interface {
// ClearForTakeoff atomically checks that no inflight or completed
// payments exist for this payment hash. If none are found, this method
// atomically transitions the status for this payment hash as InFlight.
ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error
// Success transitions an InFlight payment into a Completed payment.
// After invoking this method, ClearForTakeoff should always return an
// error to prevent us from making duplicate payments to the same
// payment hash.
Success(paymentHash [32]byte) error
// Fail transitions an InFlight payment into a Grounded Payment. After
// invoking this method, ClearForTakeoff should return nil on its next
// call for this payment hash, allowing the switch to make a subsequent
// payment.
Fail(paymentHash [32]byte) error
}
// paymentControl is persistent implementation of ControlTower to restrict
// double payment sending.
type paymentControl struct {
strict bool
db *channeldb.DB
}
// NewPaymentControl creates a new instance of the paymentControl. The strict
// flag indicates whether the controller should require "strict" state
// transitions, which would be otherwise intolerant to older databases that may
// already have duplicate payments to the same payment hash. It should be
// enabled only after sufficient checks have been made to ensure the db does not
// contain such payments. In the meantime, non-strict mode enforces a superset
// of the state transitions that prevent additional payments to a given payment
// hash from being added.
func NewPaymentControl(strict bool, db *channeldb.DB) ControlTower {
return &paymentControl{
strict: strict,
db: db,
}
}
// ClearForTakeoff checks that we don't already have an InFlight or Completed
// payment identified by the same payment hash.
func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error {
var takeoffErr error
err := p.db.Batch(func(tx *bolt.Tx) error {
// Retrieve current status of payment from local database.
paymentStatus, err := channeldb.FetchPaymentStatusTx(
tx, htlc.PaymentHash,
)
if err != nil {
return err
}
// Reset the takeoff error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
takeoffErr = nil
switch paymentStatus {
case channeldb.StatusGrounded:
// It is safe to reattempt a payment if we know that we
// haven't left one in flight. Since this one is
// grounded, Transition the payment status to InFlight
// to prevent others.
return channeldb.UpdatePaymentStatusTx(
tx, htlc.PaymentHash, channeldb.StatusInFlight,
)
case channeldb.StatusInFlight:
// We already have an InFlight payment on the network. We will
// disallow any more payment until a response is received.
takeoffErr = ErrPaymentInFlight
case channeldb.StatusCompleted:
// We've already completed a payment to this payment hash,
// forbid the switch from sending another.
takeoffErr = ErrAlreadyPaid
default:
takeoffErr = ErrUnknownPaymentStatus
}
return nil
})
if err != nil {
return err
}
return takeoffErr
}
// Success transitions an InFlight payment to Completed, otherwise it returns an
// error. After calling Success, ClearForTakeoff should prevent any further
// attempts for the same payment hash.
func (p *paymentControl) Success(paymentHash [32]byte) error {
var updateErr error
err := p.db.Batch(func(tx *bolt.Tx) error {
paymentStatus, err := channeldb.FetchPaymentStatusTx(
tx, paymentHash,
)
if err != nil {
return err
}
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
switch {
case paymentStatus == channeldb.StatusGrounded && p.strict:
// Our records show the payment as still being grounded,
// meaning it never should have left the switch.
updateErr = ErrPaymentNotInitiated
case paymentStatus == channeldb.StatusGrounded && !p.strict:
// Though our records show the payment as still being
// grounded, meaning it never should have left the
// switch, we permit this transition in non-strict mode
// to handle inconsistent db states.
fallthrough
case paymentStatus == channeldb.StatusInFlight:
// A successful response was received for an InFlight
// payment, mark it as completed to prevent sending to
// this payment hash again.
return channeldb.UpdatePaymentStatusTx(
tx, paymentHash, channeldb.StatusCompleted,
)
case paymentStatus == channeldb.StatusCompleted:
// The payment was completed previously, alert the
// caller that this may be a duplicate call.
updateErr = ErrPaymentAlreadyCompleted
default:
updateErr = ErrUnknownPaymentStatus
}
return nil
})
if err != nil {
return err
}
return updateErr
}
// Fail transitions an InFlight payment to Grounded, otherwise it returns an
// error. After calling Fail, ClearForTakeoff should fail any further attempts
// for the same payment hash.
func (p *paymentControl) Fail(paymentHash [32]byte) error {
var updateErr error
err := p.db.Batch(func(tx *bolt.Tx) error {
paymentStatus, err := channeldb.FetchPaymentStatusTx(
tx, paymentHash,
)
if err != nil {
return err
}
// Reset the update error, to avoid carrying over an error
// from a previous execution of the batched db transaction.
updateErr = nil
switch {
case paymentStatus == channeldb.StatusGrounded && p.strict:
// Our records show the payment as still being grounded,
// meaning it never should have left the switch.
updateErr = ErrPaymentNotInitiated
case paymentStatus == channeldb.StatusGrounded && !p.strict:
// Though our records show the payment as still being
// grounded, meaning it never should have left the
// switch, we permit this transition in non-strict mode
// to handle inconsistent db states.
fallthrough
case paymentStatus == channeldb.StatusInFlight:
// A failed response was received for an InFlight
// payment, mark it as Grounded again to allow
// subsequent attempts.
return channeldb.UpdatePaymentStatusTx(
tx, paymentHash, channeldb.StatusGrounded,
)
case paymentStatus == channeldb.StatusCompleted:
// The payment was completed previously, and we are now
// reporting that it has failed. Leave the status as
// completed, but alert the user that something is
// wrong.
updateErr = ErrPaymentAlreadyCompleted
default:
updateErr = ErrUnknownPaymentStatus
}
return nil
})
if err != nil {
return err
}
return updateErr
}

@ -0,0 +1,351 @@
package htlcswitch
import (
"fmt"
"testing"
"github.com/btcsuite/fastsha256"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
)
func genHtlc() (*lnwire.UpdateAddHTLC, error) {
preimage, err := genPreimage()
if err != nil {
return nil, fmt.Errorf("unable to generate preimage: %v", err)
}
rhash := fastsha256.Sum256(preimage[:])
htlc := &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
}
return htlc, nil
}
type paymentControlTestCase func(*testing.T, bool)
var paymentControlTests = []struct {
name string
strict bool
testcase paymentControlTestCase
}{
{
name: "fail-strict",
strict: true,
testcase: testPaymentControlSwitchFail,
},
{
name: "double-send-strict",
strict: true,
testcase: testPaymentControlSwitchDoubleSend,
},
{
name: "double-pay-strict",
strict: true,
testcase: testPaymentControlSwitchDoublePay,
},
{
name: "fail-not-strict",
strict: false,
testcase: testPaymentControlSwitchFail,
},
{
name: "double-send-not-strict",
strict: false,
testcase: testPaymentControlSwitchDoubleSend,
},
{
name: "double-pay-not-strict",
strict: false,
testcase: testPaymentControlSwitchDoublePay,
},
}
// TestPaymentControls runs a set of common tests against both the strict and
// non-strict payment control instances. This ensures that the two both behave
// identically when making the expected state-transitions of the stricter
// implementation. Behavioral differences in the strict and non-strict
// implementations are tested separately.
func TestPaymentControls(t *testing.T) {
for _, test := range paymentControlTests {
t.Run(test.name, func(t *testing.T) {
test.testcase(t, test.strict)
})
}
}
// testPaymentControlSwitchFail checks that payment status returns to Grounded
// status after failing, and that ClearForTakeoff allows another HTLC for the
// same payment hash.
func testPaymentControlSwitchFail(t *testing.T, strict bool) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(strict, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Sends base htlc message which initiate StatusInFlight.
if err := pControl.ClearForTakeoff(htlc); err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight)
// Fail the payment, which should moved it to Grounded.
if err := pControl.Fail(htlc.PaymentHash); err != nil {
t.Fatalf("unable to fail payment hash: %v", err)
}
// Verify the status is indeed Grounded.
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded)
// Sends the htlc again, which should succeed since the prior payment
// failed.
if err := pControl.ClearForTakeoff(htlc); err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight)
// Verifies that status was changed to StatusCompleted.
if err := pControl.Success(htlc.PaymentHash); err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted)
// Attempt a final payment, which should now fail since the prior
// payment succeed.
if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid {
t.Fatalf("unable to send htlc message: %v", err)
}
}
// testPaymentControlSwitchDoubleSend checks the ability of payment control to
// prevent double sending of htlc message, when message is in StatusInFlight.
func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(strict, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Sends base htlc message which initiate base status and move it to
// StatusInFlight and verifies that it was changed.
if err := pControl.ClearForTakeoff(htlc); err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight)
// Try to initiate double sending of htlc message with the same
// payment hash, should result in error indicating that payment has
// already been sent.
if err := pControl.ClearForTakeoff(htlc); err != ErrPaymentInFlight {
t.Fatalf("payment control wrong behaviour: " +
"double sending must trigger ErrPaymentInFlight error")
}
}
// TestPaymentControlSwitchDoublePay checks the ability of payment control to
// prevent double payment.
func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(strict, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
// Sends base htlc message which initiate StatusInFlight.
if err := pControl.ClearForTakeoff(htlc); err != nil {
t.Fatalf("unable to send htlc message: %v", err)
}
// Verify that payment is InFlight.
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight)
// Move payment to completed status, second payment should return error.
if err := pControl.Success(htlc.PaymentHash); err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
// Verify that payment is Completed.
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted)
if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid {
t.Fatalf("payment control wrong behaviour:" +
" double payment must trigger ErrAlreadyPaid")
}
}
// TestPaymentControlNonStrictSuccessesWithoutInFlight checks that a non-strict
// payment control will allow calls to Success when no payment is in flight. This
// is necessary to gracefully handle the case in which the switch already sent
// out a payment for a particular payment hash in a prior db version that didn't
// have payment statuses.
func TestPaymentControlNonStrictSuccessesWithoutInFlight(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(false, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
if err := pControl.Success(htlc.PaymentHash); err != nil {
t.Fatalf("unable to mark payment hash success: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted)
err = pControl.Success(htlc.PaymentHash)
if err != ErrPaymentAlreadyCompleted {
t.Fatalf("unable to remark payment hash failed: %v", err)
}
}
// TestPaymentControlNonStrictFailsWithoutInFlight checks that a non-strict
// payment control will allow calls to Fail when no payment is in flight. This
// is necessary to gracefully handle the case in which the switch already sent
// out a payment for a particular payment hash in a prior db version that didn't
// have payment statuses.
func TestPaymentControlNonStrictFailsWithoutInFlight(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(false, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
if err := pControl.Fail(htlc.PaymentHash); err != nil {
t.Fatalf("unable to mark payment hash failed: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded)
err = pControl.Fail(htlc.PaymentHash)
if err != nil {
t.Fatalf("unable to remark payment hash failed: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded)
err = pControl.Success(htlc.PaymentHash)
if err != nil {
t.Fatalf("unable to remark payment hash success: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted)
err = pControl.Fail(htlc.PaymentHash)
if err != ErrPaymentAlreadyCompleted {
t.Fatalf("unable to remark payment hash failed: %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted)
}
// TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment
// control will disallow calls to Success when no payment is in flight.
func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(true, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
err = pControl.Success(htlc.PaymentHash)
if err != ErrPaymentNotInitiated {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded)
}
// TestPaymentControlStrictFailsWithoutInFlight checks that a strict payment
// control will disallow calls to Fail when no payment is in flight.
func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) {
t.Parallel()
db, err := initDB()
if err != nil {
t.Fatalf("unable to init db: %v", err)
}
pControl := NewPaymentControl(true, db)
htlc, err := genHtlc()
if err != nil {
t.Fatalf("unable to generate htlc message: %v", err)
}
err = pControl.Fail(htlc.PaymentHash)
if err != ErrPaymentNotInitiated {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
}
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded)
}
func assertPaymentStatus(t *testing.T, db *channeldb.DB,
hash [32]byte, expStatus channeldb.PaymentStatus) {
t.Helper()
pStatus, err := db.FetchPaymentStatus(hash)
if err != nil {
t.Fatalf("unable to fetch payment status: %v", err)
}
if pStatus != expStatus {
t.Fatalf("payment status mismatch: expected %v, got %v",
expStatus, pStatus)
}
}

@ -3755,8 +3755,8 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
n.firstBobChannelLink.ShortChanID(), htlc, n.firstBobChannelLink.ShortChanID(), htlc,
newMockDeobfuscator(), newMockDeobfuscator(),
) )
if err != nil { if err != ErrAlreadyPaid {
t.Fatalf("error shouldn't have been received got: %v", err) t.Fatalf("ErrAlreadyPaid should have been received got: %v", err)
} }
} }

@ -124,14 +124,25 @@ type mockServer struct {
var _ lnpeer.Peer = (*mockServer)(nil) var _ lnpeer.Peer = (*mockServer)(nil)
func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) { func initDB() (*channeldb.DB, error) {
if db == nil { tempPath, err := ioutil.TempDir("", "switchdb")
tempPath, err := ioutil.TempDir("", "switchdb") if err != nil {
if err != nil { return nil, err
return nil, err }
}
db, err = channeldb.Open(tempPath) db, err := channeldb.Open(tempPath)
if err != nil {
return nil, err
}
return db, err
}
func initSwitchWithDB(startingHeight uint32, db *channeldb.DB) (*Switch, error) {
var err error
if db == nil {
db, err = initDB()
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -10,11 +10,10 @@ import (
"time" "time"
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
@ -47,6 +46,11 @@ var (
// txn. // txn.
ErrIncompleteForward = errors.New("incomplete forward detected") ErrIncompleteForward = errors.New("incomplete forward detected")
// ErrUnknownErrorDecryptor signals that we were unable to locate the
// error decryptor for this payment. This is likely due to restarting
// the daemon.
ErrUnknownErrorDecryptor = errors.New("unknown error decryptor")
// ErrSwitchExiting signaled when the switch has received a shutdown // ErrSwitchExiting signaled when the switch has received a shutdown
// request. // request.
ErrSwitchExiting = errors.New("htlcswitch shutting down") ErrSwitchExiting = errors.New("htlcswitch shutting down")
@ -64,7 +68,6 @@ type pendingPayment struct {
amount lnwire.MilliSatoshi amount lnwire.MilliSatoshi
preimage chan [sha256.Size]byte preimage chan [sha256.Size]byte
response chan *htlcPacket
err chan error err chan error
// deobfuscator is a serializable entity which is used if we received // deobfuscator is a serializable entity which is used if we received
@ -204,6 +207,9 @@ type Switch struct {
paymentSequencer Sequencer paymentSequencer Sequencer
// control provides verification of sending htlc mesages
control ControlTower
// circuits is storage for payment circuits which are used to // circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator. // forward the settle/fail htlc updates back to the add htlc initiator.
circuits CircuitMap circuits CircuitMap
@ -289,6 +295,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
cfg: &cfg, cfg: &cfg,
circuits: circuitMap, circuits: circuitMap,
paymentSequencer: sequencer, paymentSequencer: sequencer,
control: NewPaymentControl(false, cfg.DB),
linkIndex: make(map[lnwire.ChannelID]ChannelLink), linkIndex: make(map[lnwire.ChannelID]ChannelLink),
mailOrchestrator: newMailOrchestrator(), mailOrchestrator: newMailOrchestrator(),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
@ -344,11 +351,17 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID,
htlc *lnwire.UpdateAddHTLC, htlc *lnwire.UpdateAddHTLC,
deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) { deobfuscator ErrorDecrypter) ([sha256.Size]byte, error) {
// Before sending, double check that we don't already have 1) an
// in-flight payment to this payment hash, or 2) a complete payment for
// the same hash.
if err := s.control.ClearForTakeoff(htlc); err != nil {
return zeroPreimage, err
}
// Create payment and add to the map of payment in order later to be // Create payment and add to the map of payment in order later to be
// able to retrieve it and return response to the user. // able to retrieve it and return response to the user.
payment := &pendingPayment{ payment := &pendingPayment{
err: make(chan error, 1), err: make(chan error, 1),
response: make(chan *htlcPacket, 1),
preimage: make(chan [sha256.Size]byte, 1), preimage: make(chan [sha256.Size]byte, 1),
paymentHash: htlc.PaymentHash, paymentHash: htlc.PaymentHash,
amount: htlc.Amount, amount: htlc.Amount,
@ -376,13 +389,16 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID,
if err := s.forward(packet); err != nil { if err := s.forward(packet); err != nil {
s.removePendingPayment(paymentID) s.removePendingPayment(paymentID)
if err := s.control.Fail(htlc.PaymentHash); err != nil {
return zeroPreimage, err
}
return zeroPreimage, err return zeroPreimage, err
} }
// Returns channels so that other subsystem might wait/skip the // Returns channels so that other subsystem might wait/skip the
// waiting of handling of payment. // waiting of handling of payment.
var preimage [sha256.Size]byte var preimage [sha256.Size]byte
var response *htlcPacket
select { select {
case e := <-payment.err: case e := <-payment.err:
@ -391,13 +407,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID,
return zeroPreimage, ErrSwitchExiting return zeroPreimage, ErrSwitchExiting
} }
select {
case pkt := <-payment.response:
response = pkt
case <-s.quit:
return zeroPreimage, ErrSwitchExiting
}
select { select {
case p := <-payment.preimage: case p := <-payment.preimage:
preimage = p preimage = p
@ -405,24 +414,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID,
return zeroPreimage, ErrSwitchExiting return zeroPreimage, ErrSwitchExiting
} }
// Remove circuit since we are about to complete an add/fail of this
// HTLC.
if teardownErr := s.teardownCircuit(response); teardownErr != nil {
log.Warnf("unable to teardown circuit %s: %v",
response.inKey(), teardownErr)
return preimage, err
}
// Finally, if this response is contained in a forwarding package, ack
// the settle/fail so that we don't continue to retransmit the HTLC
// internally.
if response.destRef != nil {
if ackErr := s.ackSettleFail(*response.destRef); ackErr != nil {
log.Warnf("unable to ack settle/fail reference: %s: %v",
*response.destRef, ackErr)
}
}
return preimage, err return preimage, err
} }
@ -770,20 +761,10 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
// Alice Bob Carol // Alice Bob Carol
// //
func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
// Pending payments use a special interpretation of the incomingChanID and
// incomingHTLCID fields on packet where the channel ID is blank and the
// HTLC ID is the payment ID. The switch basically views the users of the
// node as a special channel that also offers a sequence of HTLCs.
payment, err := s.findPayment(pkt.incomingHTLCID)
if err != nil {
return err
}
switch htlc := pkt.htlc.(type) {
// User have created the htlc update therefore we should find the // User have created the htlc update therefore we should find the
// appropriate channel link and send the payment over this link. // appropriate channel link and send the payment over this link.
case *lnwire.UpdateAddHTLC: if htlc, ok := pkt.htlc.(*lnwire.UpdateAddHTLC); ok {
// Try to find links by node destination.
s.indexMtx.RLock() s.indexMtx.RLock()
link, err := s.getLinkByShortID(pkt.outgoingChanID) link, err := s.getLinkByShortID(pkt.outgoingChanID)
s.indexMtx.RUnlock() s.indexMtx.RUnlock()
@ -827,31 +808,113 @@ func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error {
} }
return link.HandleSwitchPacket(pkt) return link.HandleSwitchPacket(pkt)
// We've just received a settle update which means we can finalize the
// user payment and return successful response.
case *lnwire.UpdateFulfillHTLC:
// Notify the user that his payment was successfully proceed.
payment.err <- nil
payment.response <- pkt
payment.preimage <- htlc.PaymentPreimage
s.removePendingPayment(pkt.incomingHTLCID)
// We've just received a fail update which means we can finalize the
// user payment and return fail response.
case *lnwire.UpdateFailHTLC:
payment.err <- s.parseFailedPayment(payment, pkt, htlc)
payment.response <- pkt
payment.preimage <- zeroPreimage
s.removePendingPayment(pkt.incomingHTLCID)
default:
return errors.New("wrong update type")
} }
s.wg.Add(1)
go s.handleLocalResponse(pkt)
return nil return nil
} }
// handleLocalResponse processes a Settle or Fail responding to a
// locally-initiated payment. This is handled asynchronously to avoid blocking
// the main event loop within the switch, as these operations can require
// multiple db transactions. The guarantees of the circuit map are stringent
// enough such that we are able to tolerate reordering of these operations
// without side effects. The primary operations handled are:
// 1. Ack settle/fail references, to avoid resending this response internally
// 2. Teardown the closing circuit in the circuit map
// 3. Transition the payment status to grounded or completed.
// 4. Respond to an in-mem pending payment, if it is found.
//
// NOTE: This method MUST be spawned as a goroutine.
func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
defer s.wg.Done()
// First, we'll clean up any fwdpkg references, circuit entries, and
// mark in our db that the payment for this payment hash has either
// succeeded or failed.
//
// If this response is contained in a forwarding package, we'll start by
// acking the settle/fail so that we don't continue to retransmit the
// HTLC internally.
if pkt.destRef != nil {
if err := s.ackSettleFail(*pkt.destRef); err != nil {
log.Warnf("Unable to ack settle/fail reference: %s: %v",
*pkt.destRef, err)
return
}
}
// Next, we'll remove the circuit since we are about to complete an
// fulfill/fail of this HTLC. Since we've already removed the
// settle/fail fwdpkg reference, the response from the peer cannot be
// replayed internally if this step fails. If this happens, this logic
// will be executed when a provided resolution message comes through.
// This can only happen if the circuit is still open, which is why this
// ordering is chosen.
if err := s.teardownCircuit(pkt); err != nil {
log.Warnf("Unable to teardown circuit %s: %v",
pkt.inKey(), err)
return
}
// Locate the pending payment to notify the application that this
// payment has failed. If one is not found, it likely means the daemon
// has been restarted since sending the payment.
payment := s.findPayment(pkt.incomingHTLCID)
var (
preimage [32]byte
paymentErr error
)
switch htlc := pkt.htlc.(type) {
// We've received a settle update which means we can finalize the user
// payment and return successful response.
case *lnwire.UpdateFulfillHTLC:
// Persistently mark that a payment to this payment hash
// succeeded. This will prevent us from ever making another
// payment to this hash.
err := s.control.Success(pkt.circuit.PaymentHash)
if err != nil && err != ErrPaymentAlreadyCompleted {
log.Warnf("Unable to mark completed payment %x: %v",
pkt.circuit.PaymentHash, err)
return
}
preimage = htlc.PaymentPreimage
// We've received a fail update which means we can finalize the user
// payment and return fail response.
case *lnwire.UpdateFailHTLC:
// Persistently mark that a payment to this payment hash failed.
// This will permit us to make another attempt at a successful
// payment.
err := s.control.Fail(pkt.circuit.PaymentHash)
if err != nil && err != ErrPaymentAlreadyCompleted {
log.Warnf("Unable to ground payment %x: %v",
pkt.circuit.PaymentHash, err)
return
}
paymentErr = s.parseFailedPayment(payment, pkt, htlc)
default:
log.Warnf("Received unknown response type: %T", pkt.htlc)
return
}
// Deliver the payment error and preimage to the application, if it is
// waiting for a response.
if payment != nil {
payment.err <- paymentErr
payment.preimage <- preimage
s.removePendingPayment(pkt.incomingHTLCID)
}
}
// parseFailedPayment determines the appropriate failure message to return to // parseFailedPayment determines the appropriate failure message to return to
// a user initiated payment. The three cases handled are: // a user initiated payment. The three cases handled are:
// 1) A local failure, which should already plaintext. // 1) A local failure, which should already plaintext.
@ -874,7 +937,8 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
failureMsg, err := lnwire.DecodeFailure(r, 0) failureMsg, err := lnwire.DecodeFailure(r, 0)
if err != nil { if err != nil {
userErr = fmt.Sprintf("unable to decode onion failure, "+ userErr = fmt.Sprintf("unable to decode onion failure, "+
"htlc with hash(%x): %v", payment.paymentHash[:], err) "htlc with hash(%x): %v",
pkt.circuit.PaymentHash[:], err)
log.Error(userErr) log.Error(userErr)
// As this didn't even clear the link, we don't need to // As this didn't even clear the link, we don't need to
@ -901,6 +965,18 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
FailureMessage: lnwire.FailPermanentChannelFailure{}, FailureMessage: lnwire.FailPermanentChannelFailure{},
} }
// If the provided payment is nil, we have discarded the error decryptor
// due to a restart. We'll return a fixed error and signal a temporary
// channel failure to the router.
case payment == nil:
userErr := fmt.Sprintf("error decryptor for payment " +
"could not be located, likely due to restart")
failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey,
ExtraMsg: userErr,
FailureMessage: lnwire.NewTemporaryChannelFailure(nil),
}
// A regular multi-hop payment error that we'll need to // A regular multi-hop payment error that we'll need to
// decrypt. // decrypt.
default: default:
@ -909,8 +985,9 @@ func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket,
// error. If we're unable to then we'll bail early. // error. If we're unable to then we'll bail early.
failure, err = payment.deobfuscator.DecryptError(htlc.Reason) failure, err = payment.deobfuscator.DecryptError(htlc.Reason)
if err != nil { if err != nil {
userErr := fmt.Sprintf("unable to de-obfuscate onion failure, "+ userErr := fmt.Sprintf("unable to de-obfuscate onion "+
"htlc with hash(%x): %v", payment.paymentHash[:], err) "failure, htlc with hash(%x): %v",
pkt.circuit.PaymentHash[:], err)
log.Error(userErr) log.Error(userErr)
failure = &ForwardingError{ failure = &ForwardingError{
ErrorSource: s.cfg.SelfKey, ErrorSource: s.cfg.SelfKey,
@ -2042,30 +2119,26 @@ func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
// removePendingPayment is the helper function which removes the pending user // removePendingPayment is the helper function which removes the pending user
// payment. // payment.
func (s *Switch) removePendingPayment(paymentID uint64) error { func (s *Switch) removePendingPayment(paymentID uint64) {
s.pendingMutex.Lock() s.pendingMutex.Lock()
defer s.pendingMutex.Unlock() defer s.pendingMutex.Unlock()
if _, ok := s.pendingPayments[paymentID]; !ok {
return fmt.Errorf("Cannot find pending payment with ID %d",
paymentID)
}
delete(s.pendingPayments, paymentID) delete(s.pendingPayments, paymentID)
return nil
} }
// findPayment is the helper function which find the payment. // findPayment is the helper function which find the payment.
func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) { func (s *Switch) findPayment(paymentID uint64) *pendingPayment {
s.pendingMutex.RLock() s.pendingMutex.RLock()
defer s.pendingMutex.RUnlock() defer s.pendingMutex.RUnlock()
payment, ok := s.pendingPayments[paymentID] payment, ok := s.pendingPayments[paymentID]
if !ok { if !ok {
return nil, fmt.Errorf("Cannot find pending payment with ID %d", log.Errorf("Cannot find pending payment with ID %d",
paymentID) paymentID)
return nil
} }
return payment, nil
return payment
} }
// CircuitModifier returns a reference to subset of the interfaces provided by // CircuitModifier returns a reference to subset of the interfaces provided by
@ -2077,6 +2150,9 @@ func (s *Switch) CircuitModifier() CircuitModifier {
// numPendingPayments is helper function which returns the overall number of // numPendingPayments is helper function which returns the overall number of
// pending user payments. // pending user payments.
func (s *Switch) numPendingPayments() int { func (s *Switch) numPendingPayments() int {
s.pendingMutex.RLock()
defer s.pendingMutex.RUnlock()
return len(s.pendingPayments) return len(s.pendingPayments)
} }

@ -1678,7 +1678,9 @@ func TestSwitchSendPayment(t *testing.T) {
} }
case err := <-errChan: case err := <-errChan:
t.Fatalf("unable to send payment: %v", err) if err != ErrPaymentInFlight {
t.Fatalf("unable to send payment: %v", err)
}
case <-time.After(time.Second): case <-time.After(time.Second):
t.Fatal("request was not propagated to destination") t.Fatal("request was not propagated to destination")
} }
@ -1695,11 +1697,11 @@ func TestSwitchSendPayment(t *testing.T) {
t.Fatal("request was not propagated to destination") t.Fatal("request was not propagated to destination")
} }
if s.numPendingPayments() != 2 { if s.numPendingPayments() != 1 {
t.Fatal("wrong amount of pending payments") t.Fatal("wrong amount of pending payments")
} }
if s.circuits.NumOpen() != 2 { if s.circuits.NumOpen() != 1 {
t.Fatal("wrong amount of circuits") t.Fatal("wrong amount of circuits")
} }
@ -1735,29 +1737,6 @@ func TestSwitchSendPayment(t *testing.T) {
t.Fatal("err wasn't received") t.Fatal("err wasn't received")
} }
packet = &htlcPacket{
outgoingChanID: aliceChannelLink.ShortChanID(),
outgoingHTLCID: 1,
htlc: &lnwire.UpdateFailHTLC{
Reason: reason,
},
}
// Send second failure response and check that user were able to
// receive the error.
if err := s.forward(packet); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
select {
case err := <-errChan:
if err.Error() != errors.New(lnwire.CodeIncorrectPaymentAmount).Error() {
t.Fatal("err wasn't received")
}
case <-time.After(time.Second):
t.Fatal("err wasn't received")
}
if s.numPendingPayments() != 0 { if s.numPendingPayments() != 0 {
t.Fatal("wrong amount of pending payments") t.Fatal("wrong amount of pending payments")
} }

@ -450,6 +450,17 @@ func completePaymentRequests(ctx context.Context, client lnrpc.LightningClient,
return nil return nil
} }
// makeFakePayHash creates random pre image hash
func makeFakePayHash(t *harnessTest) []byte {
randBuf := make([]byte, 32)
if _, err := rand.Read(randBuf); err != nil {
t.Fatalf("internal error, cannot generate random string: %v", err)
}
return randBuf
}
const ( const (
AddrTypeWitnessPubkeyHash = lnrpc.NewAddressRequest_WITNESS_PUBKEY_HASH AddrTypeWitnessPubkeyHash = lnrpc.NewAddressRequest_WITNESS_PUBKEY_HASH
AddrTypeNestedPubkeyHash = lnrpc.NewAddressRequest_NESTED_PUBKEY_HASH AddrTypeNestedPubkeyHash = lnrpc.NewAddressRequest_NESTED_PUBKEY_HASH
@ -1822,13 +1833,13 @@ func testChannelForceClosure(net *lntest.NetworkHarness, t *harnessTest) {
if err != nil { if err != nil {
t.Fatalf("unable to create payment stream for alice: %v", err) t.Fatalf("unable to create payment stream for alice: %v", err)
} }
carolPubKey := carol.PubKey[:] carolPubKey := carol.PubKey[:]
payHash := bytes.Repeat([]byte{2}, 32)
for i := 0; i < numInvoices; i++ { for i := 0; i < numInvoices; i++ {
err = alicePayStream.Send(&lnrpc.SendRequest{ err = alicePayStream.Send(&lnrpc.SendRequest{
Dest: carolPubKey, Dest: carolPubKey,
Amt: int64(paymentAmt), Amt: int64(paymentAmt),
PaymentHash: payHash, PaymentHash: makeFakePayHash(t),
FinalCltvDelta: defaultBitcoinTimeLockDelta, FinalCltvDelta: defaultBitcoinTimeLockDelta,
}) })
if err != nil { if err != nil {
@ -3945,9 +3956,16 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) {
const paymentAmt = 70000 const paymentAmt = 70000
payReqs := make([]string, numPayments) payReqs := make([]string, numPayments)
for i := 0; i < numPayments; i++ { for i := 0; i < numPayments; i++ {
preimage := make([]byte, 32)
_, err := rand.Read(preimage)
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
invoice := &lnrpc.Invoice{ invoice := &lnrpc.Invoice{
Memo: "testing", Memo: "testing",
Value: paymentAmt, RPreimage: preimage,
Value: paymentAmt,
} }
resp, err := net.Bob.AddInvoice(ctxb, invoice) resp, err := net.Bob.AddInvoice(ctxb, invoice)
if err != nil { if err != nil {
@ -4008,9 +4026,16 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) {
const paymentAmt60k = 60000 const paymentAmt60k = 60000
payReqs = make([]string, numPayments) payReqs = make([]string, numPayments)
for i := 0; i < numPayments; i++ { for i := 0; i < numPayments; i++ {
preimage := make([]byte, 32)
_, err := rand.Read(preimage)
if err != nil {
t.Fatalf("unable to generate preimage: %v", err)
}
invoice := &lnrpc.Invoice{ invoice := &lnrpc.Invoice{
Memo: "testing", Memo: "testing",
Value: paymentAmt60k, RPreimage: preimage,
Value: paymentAmt60k,
} }
resp, err := carol.AddInvoice(ctxb, invoice) resp, err := carol.AddInvoice(ctxb, invoice)
if err != nil { if err != nil {
@ -4496,10 +4521,9 @@ func testInvoiceSubscriptions(net *lntest.NetworkHarness, t *harnessTest) {
// TODO(roasbeef): make global list of invoices for each node to re-use // TODO(roasbeef): make global list of invoices for each node to re-use
// and avoid collisions // and avoid collisions
const paymentAmt = 1000 const paymentAmt = 1000
preimage := bytes.Repeat([]byte{byte(90)}, 32)
invoice := &lnrpc.Invoice{ invoice := &lnrpc.Invoice{
Memo: "testing", Memo: "testing",
RPreimage: preimage, RPreimage: makeFakePayHash(t),
Value: paymentAmt, Value: paymentAmt,
} }
invoiceResp, err := net.Bob.AddInvoice(ctxb, invoice) invoiceResp, err := net.Bob.AddInvoice(ctxb, invoice)
@ -6727,7 +6751,7 @@ out:
// stream on payment error. // stream on payment error.
ctxt, _ = context.WithTimeout(ctxb, timeout) ctxt, _ = context.WithTimeout(ctxb, timeout)
sendReq := &lnrpc.SendRequest{ sendReq := &lnrpc.SendRequest{
PaymentHashString: hex.EncodeToString(bytes.Repeat([]byte("Z"), 32)), PaymentHashString: hex.EncodeToString(makeFakePayHash(t)),
DestString: hex.EncodeToString(carol.PubKey[:]), DestString: hex.EncodeToString(carol.PubKey[:]),
Amt: payAmt, Amt: payAmt,
} }
@ -6856,6 +6880,12 @@ out:
"instead: %v", resp.PaymentError) "instead: %v", resp.PaymentError)
} }
// Generate new invoice to not pay same invoice twice.
carolInvoice, err = carol.AddInvoice(ctxb, invoiceReq)
if err != nil {
t.Fatalf("unable to generate carol invoice: %v", err)
}
// For our final test, we'll ensure that if a target link isn't // For our final test, we'll ensure that if a target link isn't
// available for what ever reason then the payment fails accordingly. // available for what ever reason then the payment fails accordingly.
// //
@ -7953,8 +7983,8 @@ func testMultiHopHtlcLocalTimeout(net *lntest.NetworkHarness, t *harnessTest) {
// We'll create two random payment hashes unknown to carol, then send // We'll create two random payment hashes unknown to carol, then send
// each of them by manually specifying the HTLC details. // each of them by manually specifying the HTLC details.
carolPubKey := carol.PubKey[:] carolPubKey := carol.PubKey[:]
dustPayHash := bytes.Repeat([]byte{1}, 32) dustPayHash := makeFakePayHash(t)
payHash := bytes.Repeat([]byte{2}, 32) payHash := makeFakePayHash(t)
err = alicePayStream.Send(&lnrpc.SendRequest{ err = alicePayStream.Send(&lnrpc.SendRequest{
Dest: carolPubKey, Dest: carolPubKey,
Amt: int64(dustHtlcAmt), Amt: int64(dustHtlcAmt),
@ -8412,7 +8442,7 @@ func testMultiHopLocalForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness,
// We'll now send a single HTLC across our multi-hop network. // We'll now send a single HTLC across our multi-hop network.
carolPubKey := carol.PubKey[:] carolPubKey := carol.PubKey[:]
payHash := bytes.Repeat([]byte{2}, 32) payHash := makeFakePayHash(t)
err = alicePayStream.Send(&lnrpc.SendRequest{ err = alicePayStream.Send(&lnrpc.SendRequest{
Dest: carolPubKey, Dest: carolPubKey,
Amt: int64(htlcAmt), Amt: int64(htlcAmt),
@ -8669,7 +8699,7 @@ func testMultiHopRemoteForceCloseOnChainHtlcTimeout(net *lntest.NetworkHarness,
// We'll now send a single HTLC across our multi-hop network. // We'll now send a single HTLC across our multi-hop network.
carolPubKey := carol.PubKey[:] carolPubKey := carol.PubKey[:]
payHash := bytes.Repeat([]byte{2}, 32) payHash := makeFakePayHash(t)
err = alicePayStream.Send(&lnrpc.SendRequest{ err = alicePayStream.Send(&lnrpc.SendRequest{
Dest: carolPubKey, Dest: carolPubKey,
Amt: int64(htlcAmt), Amt: int64(htlcAmt),