lnwallet: prevent static fee estimator fees from being modified

Modifying the static fees is not thread safe. In this commit the fees
are made immutable.
This commit is contained in:
Joost Jager 2018-12-18 09:02:27 +01:00
parent 423dd8ab9b
commit 91f3df07e4
No known key found for this signature in database
GPG Key ID: AE6B0D042C8E38D9
12 changed files with 41 additions and 32 deletions

@ -1336,7 +1336,7 @@ func createTestArbiter(t *testing.T, contractBreaches chan *ContractBreachEvent,
ba := newBreachArbiter(&BreachConfig{ ba := newBreachArbiter(&BreachConfig{
CloseLink: func(_ *wire.OutPoint, _ htlcswitch.ChannelCloseType) {}, CloseLink: func(_ *wire.OutPoint, _ htlcswitch.ChannelCloseType) {},
DB: db, DB: db,
Estimator: &lnwallet.StaticFeeEstimator{FeePerKW: 12500}, Estimator: lnwallet.NewStaticFeeEstimator(12500, 0),
GenSweepScript: func() ([]byte, error) { return nil, nil }, GenSweepScript: func() ([]byte, error) { return nil, nil },
ContractBreaches: contractBreaches, ContractBreaches: contractBreaches,
Signer: signer, Signer: signer,
@ -1476,7 +1476,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
return nil, nil, nil, err return nil, nil, nil, err
} }
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 12500} estimator := lnwallet.NewStaticFeeEstimator(12500, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err

@ -150,9 +150,9 @@ func newChainControlFromConfig(cfg *config, chanDB *channeldb.DB,
FeeRate: cfg.Bitcoin.FeeRate, FeeRate: cfg.Bitcoin.FeeRate,
TimeLockDelta: cfg.Bitcoin.TimeLockDelta, TimeLockDelta: cfg.Bitcoin.TimeLockDelta,
} }
cc.feeEstimator = lnwallet.StaticFeeEstimator{ cc.feeEstimator = lnwallet.NewStaticFeeEstimator(
FeePerKW: defaultBitcoinStaticFeePerKW, defaultBitcoinStaticFeePerKW, 0,
} )
case litecoinChain: case litecoinChain:
cc.routingPolicy = htlcswitch.ForwardingPolicy{ cc.routingPolicy = htlcswitch.ForwardingPolicy{
MinHTLC: cfg.Litecoin.MinHTLC, MinHTLC: cfg.Litecoin.MinHTLC,
@ -160,9 +160,9 @@ func newChainControlFromConfig(cfg *config, chanDB *channeldb.DB,
FeeRate: cfg.Litecoin.FeeRate, FeeRate: cfg.Litecoin.FeeRate,
TimeLockDelta: cfg.Litecoin.TimeLockDelta, TimeLockDelta: cfg.Litecoin.TimeLockDelta,
} }
cc.feeEstimator = lnwallet.StaticFeeEstimator{ cc.feeEstimator = lnwallet.NewStaticFeeEstimator(
FeePerKW: defaultLitecoinStaticFeePerKW, defaultLitecoinStaticFeePerKW, 0,
} )
default: default:
return nil, nil, fmt.Errorf("Default routing policy for "+ return nil, nil, fmt.Errorf("Default routing policy for "+
"chain %v is unknown", registeredChains.PrimaryChain()) "chain %v is unknown", registeredChains.PrimaryChain())

@ -231,7 +231,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey,
addr *lnwire.NetAddress, tempTestDir string) (*testNode, error) { addr *lnwire.NetAddress, tempTestDir string) (*testNode, error) {
netParams := activeNetParams.Params netParams := activeNetParams.Params
estimator := lnwallet.StaticFeeEstimator{FeePerKW: 62500} estimator := lnwallet.NewStaticFeeEstimator(62500, 0)
chainNotifier := &mockNotifier{ chainNotifier := &mockNotifier{
oneConfChannel: make(chan *chainntnfs.TxConfirmation, 1), oneConfChannel: make(chan *chainntnfs.TxConfirmation, 1),

@ -1795,7 +1795,7 @@ func TestChannelLinkBandwidthConsistency(t *testing.T) {
coreLink.cfg.HodlMask = hodl.MaskFromFlags(hodl.ExitSettle) coreLink.cfg.HodlMask = hodl.MaskFromFlags(hodl.ExitSettle)
coreLink.cfg.DebugHTLC = true coreLink.cfg.DebugHTLC = true
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 6000} estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)
@ -2206,7 +2206,7 @@ func TestChannelLinkBandwidthConsistencyOverflow(t *testing.T) {
aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs
) )
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 6000} estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)
@ -2453,7 +2453,7 @@ func TestChannelLinkTrimCircuitsPending(t *testing.T) {
// Compute the static fees that will be used to determine the // Compute the static fees that will be used to determine the
// correctness of Alice's bandwidth when forwarding HTLCs. // correctness of Alice's bandwidth when forwarding HTLCs.
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 6000} estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)
@ -2731,7 +2731,7 @@ func TestChannelLinkTrimCircuitsNoCommit(t *testing.T) {
// Compute the static fees that will be used to determine the // Compute the static fees that will be used to determine the
// correctness of Alice's bandwidth when forwarding HTLCs. // correctness of Alice's bandwidth when forwarding HTLCs.
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 6000} estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)
@ -2989,7 +2989,7 @@ func TestChannelLinkBandwidthChanReserve(t *testing.T) {
aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs
) )
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 6000} estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)

@ -273,7 +273,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 6000} estimator := lnwallet.NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err

@ -1131,7 +1131,7 @@ func TestHTLCSigNumber(t *testing.T) {
} }
// Calculate two values that will be below and above Bob's dust limit. // Calculate two values that will be below and above Bob's dust limit.
estimator := &StaticFeeEstimator{FeePerKW: 6000} estimator := NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to get fee: %v", err) t.Fatalf("unable to get fee: %v", err)

@ -68,22 +68,33 @@ type FeeEstimator interface {
// StaticFeeEstimator will return a static value for all fee calculation // StaticFeeEstimator will return a static value for all fee calculation
// requests. It is designed to be replaced by a proper fee calculation // requests. It is designed to be replaced by a proper fee calculation
// implementation. // implementation. The fees are not accessible directly, because changing them
// would not be thread safe.
type StaticFeeEstimator struct { type StaticFeeEstimator struct {
// FeePerKW is the static fee rate in satoshis-per-vbyte that will be // feePerKW is the static fee rate in satoshis-per-vbyte that will be
// returned by this fee estimator. // returned by this fee estimator.
FeePerKW SatPerKWeight feePerKW SatPerKWeight
// RelayFee is the minimum fee rate required for transactions to be // relayFee is the minimum fee rate required for transactions to be
// relayed. // relayed.
RelayFee SatPerKWeight relayFee SatPerKWeight
}
// NewStaticFeeEstimator returns a new static fee estimator instance.
func NewStaticFeeEstimator(feePerKW,
relayFee SatPerKWeight) *StaticFeeEstimator {
return &StaticFeeEstimator{
feePerKW: feePerKW,
relayFee: relayFee,
}
} }
// EstimateFeePerKW will return a static value for fee calculations. // EstimateFeePerKW will return a static value for fee calculations.
// //
// NOTE: This method is part of the FeeEstimator interface. // NOTE: This method is part of the FeeEstimator interface.
func (e StaticFeeEstimator) EstimateFeePerKW(numBlocks uint32) (SatPerKWeight, error) { func (e StaticFeeEstimator) EstimateFeePerKW(numBlocks uint32) (SatPerKWeight, error) {
return e.FeePerKW, nil return e.feePerKW, nil
} }
// RelayFeePerKW returns the minimum fee rate required for transactions to be // RelayFeePerKW returns the minimum fee rate required for transactions to be
@ -91,7 +102,7 @@ func (e StaticFeeEstimator) EstimateFeePerKW(numBlocks uint32) (SatPerKWeight, e
// //
// NOTE: This method is part of the FeeEstimator interface. // NOTE: This method is part of the FeeEstimator interface.
func (e StaticFeeEstimator) RelayFeePerKW() SatPerKWeight { func (e StaticFeeEstimator) RelayFeePerKW() SatPerKWeight {
return e.RelayFee return e.relayFee
} }
// Start signals the FeeEstimator to start any processes or goroutines // Start signals the FeeEstimator to start any processes or goroutines

@ -74,9 +74,7 @@ func TestStaticFeeEstimator(t *testing.T) {
const feePerKw = lnwallet.FeePerKwFloor const feePerKw = lnwallet.FeePerKwFloor
feeEstimator := &lnwallet.StaticFeeEstimator{ feeEstimator := lnwallet.NewStaticFeeEstimator(feePerKw, 0)
FeePerKW: feePerKw,
}
if err := feeEstimator.Start(); err != nil { if err := feeEstimator.Start(); err != nil {
t.Fatalf("unable to start fee estimator: %v", err) t.Fatalf("unable to start fee estimator: %v", err)
} }

@ -368,7 +368,7 @@ func createTestWallet(tempTestDir string, miningNode *rpctest.Harness,
WalletController: wc, WalletController: wc,
Signer: signer, Signer: signer,
ChainIO: bio, ChainIO: bio,
FeeEstimator: lnwallet.StaticFeeEstimator{FeePerKW: 2500}, FeeEstimator: lnwallet.NewStaticFeeEstimator(2500, 0),
DefaultConstraints: channeldb.ChannelConstraints{ DefaultConstraints: channeldb.ChannelConstraints{
DustLimit: 500, DustLimit: 500,
MaxPendingAmount: lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) * 100, MaxPendingAmount: lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) * 100,
@ -2440,7 +2440,7 @@ func runTests(t *testing.T, walletDriver *lnwallet.WalletDriver,
} }
case "neutrino": case "neutrino":
feeEstimator = lnwallet.StaticFeeEstimator{FeePerKW: 62500} feeEstimator = lnwallet.NewStaticFeeEstimator(62500, 0)
// Set some package-level variable to speed up // Set some package-level variable to speed up
// operation for tests. // operation for tests.

@ -229,7 +229,7 @@ func CreateTestChannels() (*LightningChannel, *LightningChannel, func(), error)
return nil, nil, nil, err return nil, nil, nil, err
} }
estimator := &StaticFeeEstimator{FeePerKW: 6000} estimator := NewStaticFeeEstimator(6000, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err

@ -157,7 +157,7 @@ func TestPeerChannelClosureAcceptFeeInitiator(t *testing.T) {
dummyDeliveryScript), dummyDeliveryScript),
} }
estimator := lnwallet.StaticFeeEstimator{FeePerKW: 12500} estimator := lnwallet.NewStaticFeeEstimator(12500, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)
@ -447,7 +447,7 @@ func TestPeerChannelClosureFeeNegotiationsInitiator(t *testing.T) {
msg: respShutdown, msg: respShutdown,
} }
estimator := lnwallet.StaticFeeEstimator{FeePerKW: 12500} estimator := lnwallet.NewStaticFeeEstimator(12500, 0)
initiatorIdealFeeRate, err := estimator.EstimateFeePerKW(1) initiatorIdealFeeRate, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
t.Fatalf("unable to query fee estimator: %v", err) t.Fatalf("unable to query fee estimator: %v", err)

@ -202,7 +202,7 @@ func createTestPeer(notifier chainntnfs.ChainNotifier,
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
estimator := &lnwallet.StaticFeeEstimator{FeePerKW: 12500} estimator := lnwallet.NewStaticFeeEstimator(12500, 0)
feePerKw, err := estimator.EstimateFeePerKW(1) feePerKw, err := estimator.EstimateFeePerKW(1)
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err