htlcswitch+channeldb: move control tower to channeldb

This commit is contained in:
Johan T. Halseth 2019-05-23 20:05:26 +02:00
parent 6e102d64b9
commit d027e10201
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
4 changed files with 76 additions and 85 deletions

@ -1,10 +1,9 @@
package htlcswitch package channeldb
import ( import (
"errors" "errors"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -59,7 +58,7 @@ type ControlTower interface {
type paymentControl struct { type paymentControl struct {
strict bool strict bool
db *channeldb.DB db *DB
} }
// NewPaymentControl creates a new instance of the paymentControl. The strict // NewPaymentControl creates a new instance of the paymentControl. The strict
@ -70,7 +69,7 @@ type paymentControl struct {
// contain such payments. In the meantime, non-strict mode enforces a superset // contain such payments. In the meantime, non-strict mode enforces a superset
// of the state transitions that prevent additional payments to a given payment // of the state transitions that prevent additional payments to a given payment
// hash from being added. // hash from being added.
func NewPaymentControl(strict bool, db *channeldb.DB) ControlTower { func NewPaymentControl(strict bool, db *DB) ControlTower {
return &paymentControl{ return &paymentControl{
strict: strict, strict: strict,
db: db, db: db,
@ -83,7 +82,7 @@ func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error {
var takeoffErr error var takeoffErr error
err := p.db.Batch(func(tx *bbolt.Tx) error { err := p.db.Batch(func(tx *bbolt.Tx) error {
// Retrieve current status of payment from local database. // Retrieve current status of payment from local database.
paymentStatus, err := channeldb.FetchPaymentStatusTx( paymentStatus, err := FetchPaymentStatusTx(
tx, htlc.PaymentHash, tx, htlc.PaymentHash,
) )
if err != nil { if err != nil {
@ -96,21 +95,21 @@ func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error {
switch paymentStatus { switch paymentStatus {
case channeldb.StatusGrounded: case StatusGrounded:
// It is safe to reattempt a payment if we know that we // It is safe to reattempt a payment if we know that we
// haven't left one in flight. Since this one is // haven't left one in flight. Since this one is
// grounded, Transition the payment status to InFlight // grounded, Transition the payment status to InFlight
// to prevent others. // to prevent others.
return channeldb.UpdatePaymentStatusTx( return UpdatePaymentStatusTx(
tx, htlc.PaymentHash, channeldb.StatusInFlight, tx, htlc.PaymentHash, StatusInFlight,
) )
case channeldb.StatusInFlight: case StatusInFlight:
// We already have an InFlight payment on the network. We will // We already have an InFlight payment on the network. We will
// disallow any more payment until a response is received. // disallow any more payment until a response is received.
takeoffErr = ErrPaymentInFlight takeoffErr = ErrPaymentInFlight
case channeldb.StatusCompleted: case StatusCompleted:
// We've already completed a payment to this payment hash, // We've already completed a payment to this payment hash,
// forbid the switch from sending another. // forbid the switch from sending another.
takeoffErr = ErrAlreadyPaid takeoffErr = ErrAlreadyPaid
@ -134,7 +133,7 @@ func (p *paymentControl) ClearForTakeoff(htlc *lnwire.UpdateAddHTLC) error {
func (p *paymentControl) Success(paymentHash [32]byte) error { func (p *paymentControl) Success(paymentHash [32]byte) error {
var updateErr error var updateErr error
err := p.db.Batch(func(tx *bbolt.Tx) error { err := p.db.Batch(func(tx *bbolt.Tx) error {
paymentStatus, err := channeldb.FetchPaymentStatusTx( paymentStatus, err := FetchPaymentStatusTx(
tx, paymentHash, tx, paymentHash,
) )
if err != nil { if err != nil {
@ -147,27 +146,27 @@ func (p *paymentControl) Success(paymentHash [32]byte) error {
switch { switch {
case paymentStatus == channeldb.StatusGrounded && p.strict: case paymentStatus == StatusGrounded && p.strict:
// Our records show the payment as still being grounded, // Our records show the payment as still being grounded,
// meaning it never should have left the switch. // meaning it never should have left the switch.
updateErr = ErrPaymentNotInitiated updateErr = ErrPaymentNotInitiated
case paymentStatus == channeldb.StatusGrounded && !p.strict: case paymentStatus == StatusGrounded && !p.strict:
// Though our records show the payment as still being // Though our records show the payment as still being
// grounded, meaning it never should have left the // grounded, meaning it never should have left the
// switch, we permit this transition in non-strict mode // switch, we permit this transition in non-strict mode
// to handle inconsistent db states. // to handle inconsistent db states.
fallthrough fallthrough
case paymentStatus == channeldb.StatusInFlight: case paymentStatus == StatusInFlight:
// A successful response was received for an InFlight // A successful response was received for an InFlight
// payment, mark it as completed to prevent sending to // payment, mark it as completed to prevent sending to
// this payment hash again. // this payment hash again.
return channeldb.UpdatePaymentStatusTx( return UpdatePaymentStatusTx(
tx, paymentHash, channeldb.StatusCompleted, tx, paymentHash, StatusCompleted,
) )
case paymentStatus == channeldb.StatusCompleted: case paymentStatus == StatusCompleted:
// The payment was completed previously, alert the // The payment was completed previously, alert the
// caller that this may be a duplicate call. // caller that this may be a duplicate call.
updateErr = ErrPaymentAlreadyCompleted updateErr = ErrPaymentAlreadyCompleted
@ -191,7 +190,7 @@ func (p *paymentControl) Success(paymentHash [32]byte) error {
func (p *paymentControl) Fail(paymentHash [32]byte) error { func (p *paymentControl) Fail(paymentHash [32]byte) error {
var updateErr error var updateErr error
err := p.db.Batch(func(tx *bbolt.Tx) error { err := p.db.Batch(func(tx *bbolt.Tx) error {
paymentStatus, err := channeldb.FetchPaymentStatusTx( paymentStatus, err := FetchPaymentStatusTx(
tx, paymentHash, tx, paymentHash,
) )
if err != nil { if err != nil {
@ -204,27 +203,27 @@ func (p *paymentControl) Fail(paymentHash [32]byte) error {
switch { switch {
case paymentStatus == channeldb.StatusGrounded && p.strict: case paymentStatus == StatusGrounded && p.strict:
// Our records show the payment as still being grounded, // Our records show the payment as still being grounded,
// meaning it never should have left the switch. // meaning it never should have left the switch.
updateErr = ErrPaymentNotInitiated updateErr = ErrPaymentNotInitiated
case paymentStatus == channeldb.StatusGrounded && !p.strict: case paymentStatus == StatusGrounded && !p.strict:
// Though our records show the payment as still being // Though our records show the payment as still being
// grounded, meaning it never should have left the // grounded, meaning it never should have left the
// switch, we permit this transition in non-strict mode // switch, we permit this transition in non-strict mode
// to handle inconsistent db states. // to handle inconsistent db states.
fallthrough fallthrough
case paymentStatus == channeldb.StatusInFlight: case paymentStatus == StatusInFlight:
// A failed response was received for an InFlight // A failed response was received for an InFlight
// payment, mark it as Grounded again to allow // payment, mark it as Grounded again to allow
// subsequent attempts. // subsequent attempts.
return channeldb.UpdatePaymentStatusTx( return UpdatePaymentStatusTx(
tx, paymentHash, channeldb.StatusGrounded, tx, paymentHash, StatusGrounded,
) )
case paymentStatus == channeldb.StatusCompleted: case paymentStatus == StatusCompleted:
// The payment was completed previously, and we are now // The payment was completed previously, and we are now
// reporting that it has failed. Leave the status as // reporting that it has failed. Leave the status as
// completed, but alert the user that something is // completed, but alert the user that something is

@ -1,14 +1,38 @@
package htlcswitch package channeldb
import ( import (
"crypto/rand"
"fmt" "fmt"
"io"
"io/ioutil"
"testing" "testing"
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
func initDB() (*DB, error) {
tempPath, err := ioutil.TempDir("", "switchdb")
if err != nil {
return nil, err
}
db, err := Open(tempPath)
if err != nil {
return nil, err
}
return db, err
}
func genPreimage() ([32]byte, error) {
var preimage [32]byte
if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil {
return preimage, err
}
return preimage, nil
}
func genHtlc() (*lnwire.UpdateAddHTLC, error) { func genHtlc() (*lnwire.UpdateAddHTLC, error) {
preimage, err := genPreimage() preimage, err := genPreimage()
if err != nil { if err != nil {
@ -99,7 +123,7 @@ func testPaymentControlSwitchFail(t *testing.T, strict bool) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) assertPaymentStatus(t, db, htlc.PaymentHash, StatusInFlight)
// Fail the payment, which should moved it to Grounded. // Fail the payment, which should moved it to Grounded.
if err := pControl.Fail(htlc.PaymentHash); err != nil { if err := pControl.Fail(htlc.PaymentHash); err != nil {
@ -107,7 +131,7 @@ func testPaymentControlSwitchFail(t *testing.T, strict bool) {
} }
// Verify the status is indeed Grounded. // Verify the status is indeed Grounded.
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded)
// Sends the htlc again, which should succeed since the prior payment // Sends the htlc again, which should succeed since the prior payment
// failed. // failed.
@ -115,14 +139,14 @@ func testPaymentControlSwitchFail(t *testing.T, strict bool) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) assertPaymentStatus(t, db, htlc.PaymentHash, StatusInFlight)
// Verifies that status was changed to StatusCompleted. // Verifies that status was changed to StatusCompleted.
if err := pControl.Success(htlc.PaymentHash); err != nil { if err := pControl.Success(htlc.PaymentHash); err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err) t.Fatalf("error shouldn't have been received, got: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted)
// Attempt a final payment, which should now fail since the prior // Attempt a final payment, which should now fail since the prior
// payment succeed. // payment succeed.
@ -154,7 +178,7 @@ func testPaymentControlSwitchDoubleSend(t *testing.T, strict bool) {
t.Fatalf("unable to send htlc message: %v", err) t.Fatalf("unable to send htlc message: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) assertPaymentStatus(t, db, htlc.PaymentHash, StatusInFlight)
// Try to initiate double sending of htlc message with the same // Try to initiate double sending of htlc message with the same
// payment hash, should result in error indicating that payment has // payment hash, should result in error indicating that payment has
@ -188,7 +212,7 @@ func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) {
} }
// Verify that payment is InFlight. // Verify that payment is InFlight.
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusInFlight) assertPaymentStatus(t, db, htlc.PaymentHash, StatusInFlight)
// Move payment to completed status, second payment should return error. // Move payment to completed status, second payment should return error.
if err := pControl.Success(htlc.PaymentHash); err != nil { if err := pControl.Success(htlc.PaymentHash); err != nil {
@ -196,7 +220,7 @@ func testPaymentControlSwitchDoublePay(t *testing.T, strict bool) {
} }
// Verify that payment is Completed. // Verify that payment is Completed.
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted)
if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid { if err := pControl.ClearForTakeoff(htlc); err != ErrAlreadyPaid {
t.Fatalf("payment control wrong behaviour:" + t.Fatalf("payment control wrong behaviour:" +
@ -228,7 +252,7 @@ func TestPaymentControlNonStrictSuccessesWithoutInFlight(t *testing.T) {
t.Fatalf("unable to mark payment hash success: %v", err) t.Fatalf("unable to mark payment hash success: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted)
err = pControl.Success(htlc.PaymentHash) err = pControl.Success(htlc.PaymentHash)
if err != ErrPaymentAlreadyCompleted { if err != ErrPaymentAlreadyCompleted {
@ -260,28 +284,28 @@ func TestPaymentControlNonStrictFailsWithoutInFlight(t *testing.T) {
t.Fatalf("unable to mark payment hash failed: %v", err) t.Fatalf("unable to mark payment hash failed: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded)
err = pControl.Fail(htlc.PaymentHash) err = pControl.Fail(htlc.PaymentHash)
if err != nil { if err != nil {
t.Fatalf("unable to remark payment hash failed: %v", err) t.Fatalf("unable to remark payment hash failed: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded)
err = pControl.Success(htlc.PaymentHash) err = pControl.Success(htlc.PaymentHash)
if err != nil { if err != nil {
t.Fatalf("unable to remark payment hash success: %v", err) t.Fatalf("unable to remark payment hash success: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted)
err = pControl.Fail(htlc.PaymentHash) err = pControl.Fail(htlc.PaymentHash)
if err != ErrPaymentAlreadyCompleted { if err != ErrPaymentAlreadyCompleted {
t.Fatalf("unable to remark payment hash failed: %v", err) t.Fatalf("unable to remark payment hash failed: %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusCompleted) assertPaymentStatus(t, db, htlc.PaymentHash, StatusCompleted)
} }
// TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment // TestPaymentControlStrictSuccessesWithoutInFlight checks that a strict payment
@ -306,7 +330,7 @@ func TestPaymentControlStrictSuccessesWithoutInFlight(t *testing.T) {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded)
} }
// TestPaymentControlStrictFailsWithoutInFlight checks that a strict payment // TestPaymentControlStrictFailsWithoutInFlight checks that a strict payment
@ -331,11 +355,11 @@ func TestPaymentControlStrictFailsWithoutInFlight(t *testing.T) {
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
} }
assertPaymentStatus(t, db, htlc.PaymentHash, channeldb.StatusGrounded) assertPaymentStatus(t, db, htlc.PaymentHash, StatusGrounded)
} }
func assertPaymentStatus(t *testing.T, db *channeldb.DB, func assertPaymentStatus(t *testing.T, db *DB,
hash [32]byte, expStatus channeldb.PaymentStatus) { hash [32]byte, expStatus PaymentStatus) {
t.Helper() t.Helper()

@ -3889,8 +3889,7 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
} }
// With the invoice now added to Carol's registry, we'll send the // With the invoice now added to Carol's registry, we'll send the
// payment. It should succeed w/o any issues as it has been crafted // payment.
// properly.
err = n.aliceServer.htlcSwitch.SendHTLC( err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc, n.firstBobChannelLink.ShortChanID(), pid, htlc,
) )
@ -3905,6 +3904,16 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
t.Fatalf("unable to get payment result: %v", err) t.Fatalf("unable to get payment result: %v", err)
} }
// Now, if we attempt to send the payment *again* it should be rejected
// as it's a duplicate request.
err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc,
)
if err != ErrPaymentIDAlreadyExists {
t.Fatalf("ErrPaymentIDAlreadyExists should have been "+
"received got: %v", err)
}
select { select {
case result, ok := <-resultChan: case result, ok := <-resultChan:
if !ok { if !ok {
@ -3917,15 +3926,6 @@ func TestChannelLinkAcceptDuplicatePayment(t *testing.T) {
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
t.Fatalf("payment result did not arrive") t.Fatalf("payment result did not arrive")
} }
// Now, if we attempt to send the payment *again* it should be rejected
// as it's a duplicate request.
err = n.aliceServer.htlcSwitch.SendHTLC(
n.firstBobChannelLink.ShortChanID(), pid, htlc,
)
if err != ErrAlreadyPaid {
t.Fatalf("ErrAlreadyPaid should have been received got: %v", err)
}
} }
// TestChannelLinkAcceptOverpay tests that if we create an invoice for sender, // TestChannelLinkAcceptOverpay tests that if we create an invoice for sender,

@ -208,9 +208,6 @@ type Switch struct {
pendingPayments map[uint64]*pendingPayment pendingPayments map[uint64]*pendingPayment
pendingMutex sync.RWMutex pendingMutex sync.RWMutex
// control provides verification of sending htlc mesages
control ControlTower
// circuits is storage for payment circuits which are used to // circuits is storage for payment circuits which are used to
// forward the settle/fail htlc updates back to the add htlc initiator. // forward the settle/fail htlc updates back to the add htlc initiator.
circuits CircuitMap circuits CircuitMap
@ -290,7 +287,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
bestHeight: currentHeight, bestHeight: currentHeight,
cfg: &cfg, cfg: &cfg,
circuits: circuitMap, circuits: circuitMap,
control: NewPaymentControl(false, cfg.DB),
linkIndex: make(map[lnwire.ChannelID]ChannelLink), linkIndex: make(map[lnwire.ChannelID]ChannelLink),
mailOrchestrator: newMailOrchestrator(), mailOrchestrator: newMailOrchestrator(),
forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink),
@ -402,13 +398,6 @@ func (s *Switch) GetPaymentResult(paymentID uint64,
func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64, func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
htlc *lnwire.UpdateAddHTLC) error { htlc *lnwire.UpdateAddHTLC) error {
// Before sending, double check that we don't already have 1) an
// in-flight payment to this payment hash, or 2) a complete payment for
// the same hash.
if err := s.control.ClearForTakeoff(htlc); err != nil {
return err
}
// Create payment and add to the map of payment in order later to be // Create payment and add to the map of payment in order later to be
// able to retrieve it and return response to the user. // able to retrieve it and return response to the user.
payment := &pendingPayment{ payment := &pendingPayment{
@ -439,10 +428,6 @@ func (s *Switch) SendHTLC(firstHop lnwire.ShortChannelID, paymentID uint64,
if err := s.forward(packet); err != nil { if err := s.forward(packet); err != nil {
s.removePendingPayment(paymentID) s.removePendingPayment(paymentID)
if err := s.control.Fail(htlc.PaymentHash); err != nil {
return err
}
return err return err
} }
@ -939,15 +924,6 @@ func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
// We've received a settle update which means we can finalize the user // We've received a settle update which means we can finalize the user
// payment and return successful response. // payment and return successful response.
case *lnwire.UpdateFulfillHTLC: case *lnwire.UpdateFulfillHTLC:
// Persistently mark that a payment to this payment hash
// succeeded. This will prevent us from ever making another
// payment to this hash.
err := s.control.Success(paymentHash)
if err != nil && err != ErrPaymentAlreadyCompleted {
return nil, fmt.Errorf("Unable to mark completed "+
"payment %x: %v", paymentHash, err)
}
return &PaymentResult{ return &PaymentResult{
Preimage: htlc.PaymentPreimage, Preimage: htlc.PaymentPreimage,
}, nil }, nil
@ -955,14 +931,6 @@ func (s *Switch) extractResult(deobfuscator ErrorDecrypter, n *networkResult,
// We've received a fail update which means we can finalize the // We've received a fail update which means we can finalize the
// user payment and return fail response. // user payment and return fail response.
case *lnwire.UpdateFailHTLC: case *lnwire.UpdateFailHTLC:
// Persistently mark that a payment to this payment hash
// failed. This will permit us to make another attempt at a
// successful payment.
err := s.control.Fail(paymentHash)
if err != nil && err != ErrPaymentAlreadyCompleted {
return nil, fmt.Errorf("Unable to ground payment "+
"%x: %v", paymentHash, err)
}
paymentErr := s.parseFailedPayment( paymentErr := s.parseFailedPayment(
deobfuscator, paymentID, paymentHash, n.unencrypted, deobfuscator, paymentID, paymentHash, n.unencrypted,
n.isResolution, htlc, n.isResolution, htlc,