diff --git a/input/witnessgen.go b/input/witnessgen.go index 1b8c930e..0a2fafb3 100644 --- a/input/witnessgen.go +++ b/input/witnessgen.go @@ -35,9 +35,8 @@ type WitnessType interface { SizeUpperBound() (int, bool, error) // AddWeightEstimation adds the estimated size of the witness in bytes - // to the given weight estimator and returns the number of - // CSVs/CLTVs used by the script. - AddWeightEstimation(estimator *TxWeightEstimator) (int, int, error) + // to the given weight estimator. + AddWeightEstimation(e *TxWeightEstimator) error } // StandardWitnessType is a numeric representation of standard pre-defined types @@ -160,6 +159,12 @@ func (wt StandardWitnessType) String() string { case HtlcSecondLevelRevoke: return "HtlcSecondLevelRevoke" + case WitnessKeyHash: + return "WitnessKeyHash" + + case NestedWitnessKeyHash: + return "NestedWitnessKeyHash" + default: return fmt.Sprintf("Unknown WitnessType: %v", uint32(wt)) } @@ -368,42 +373,25 @@ func (wt StandardWitnessType) SizeUpperBound() (int, bool, error) { } // AddWeightEstimation adds the estimated size of the witness in bytes to the -// given weight estimator and returns the number of CSVs/CLTVs used by the -// script. +// given weight estimator. // // NOTE: This is part of the WitnessType interface. -func (wt StandardWitnessType) AddWeightEstimation( - estimator *TxWeightEstimator) (int, int, error) { - - var ( - csvCount = 0 - cltvCount = 0 - ) - +func (wt StandardWitnessType) AddWeightEstimation(e *TxWeightEstimator) error { // For fee estimation purposes, we'll now attempt to obtain an // upper bound on the weight this input will add when fully // populated. size, isNestedP2SH, err := wt.SizeUpperBound() if err != nil { - return 0, 0, err + return err } // If this is a nested P2SH input, then we'll need to factor in // the additional data push within the sigScript. if isNestedP2SH { - estimator.AddNestedP2WSHInput(size) + e.AddNestedP2WSHInput(size) } else { - estimator.AddWitnessInput(size) + e.AddWitnessInput(size) } - switch wt { - case CommitmentTimeLock, - HtlcOfferedTimeoutSecondLevel, - HtlcAcceptedSuccessSecondLevel: - csvCount++ - case HtlcOfferedRemoteTimeout: - cltvCount++ - } - - return csvCount, cltvCount, nil + return nil } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 501ec3f8..623b11f9 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -339,7 +339,7 @@ func assertTxFeeRate(t *testing.T, tx *wire.MsgTx, outputAmt := tx.TxOut[0].Value fee := btcutil.Amount(inputAmt - outputAmt) - _, txWeight, _, _ := getWeightEstimate(inputs) + _, txWeight := getWeightEstimate(inputs) expectedFee := expectedFeeRate.FeeForWeight(txWeight) if fee != expectedFee { diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 4e383c70..4c5c5bc7 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -3,6 +3,7 @@ package sweep import ( "fmt" "sort" + "strings" "github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/txscript" @@ -171,10 +172,10 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte, currentBlockHeight uint32, feePerKw lnwallet.SatPerKWeight, signer input.Signer) (*wire.MsgTx, error) { - inputs, txWeight, csvCount, cltvCount := getWeightEstimate(inputs) + inputs, txWeight := getWeightEstimate(inputs) - log.Infof("Creating sweep transaction for %v inputs (%v CSV, %v CLTV) "+ - "using %v sat/kw", len(inputs), csvCount, cltvCount, + log.Infof("Creating sweep transaction for %v inputs (%s) "+ + "using %v sat/kw", len(inputs), inputTypeSummary(inputs), int64(feePerKw)) txFee := feePerKw.FeeForWeight(txWeight) @@ -253,7 +254,7 @@ func createSweepTx(inputs []input.Input, outputPkScript []byte, // getWeightEstimate returns a weight estimate for the given inputs. // Additionally, it returns counts for the number of csv and cltv inputs. -func getWeightEstimate(inputs []input.Input) ([]input.Input, int64, int, int) { +func getWeightEstimate(inputs []input.Input) ([]input.Input, int64) { // We initialize a weight estimator so we can accurately asses the // amount of fees we need to pay for this sweep transaction. // @@ -268,15 +269,12 @@ func getWeightEstimate(inputs []input.Input) ([]input.Input, int64, int, int) { // For each output, use its witness type to determine the estimate // weight of its witness, and add it to the proper set of spendable // outputs. - var ( - sweepInputs []input.Input - csvCount, cltvCount int - ) + var sweepInputs []input.Input for i := range inputs { inp := inputs[i] wt := inp.WitnessType() - inpCsv, inpCltv, err := wt.AddWeightEstimation(&weightEstimate) + err := wt.AddWeightEstimation(&weightEstimate) if err != nil { log.Warn(err) @@ -284,12 +282,38 @@ func getWeightEstimate(inputs []input.Input) ([]input.Input, int64, int, int) { // given. continue } - csvCount += inpCsv - cltvCount += inpCltv + sweepInputs = append(sweepInputs, inp) } - txWeight := int64(weightEstimate.Weight()) - - return sweepInputs, txWeight, csvCount, cltvCount + return sweepInputs, int64(weightEstimate.Weight()) +} + +// inputSummary returns a string containing a human readable summary about the +// witness types of a list of inputs. +func inputTypeSummary(inputs []input.Input) string { + // Count each input by the string representation of its witness type. + // We also keep track of the keys so we can later sort by them to get + // a stable output. + counts := make(map[string]uint32) + keys := make([]string, 0, len(inputs)) + for _, i := range inputs { + key := i.WitnessType().String() + _, ok := counts[key] + if !ok { + counts[key] = 0 + keys = append(keys, key) + } + counts[key]++ + } + sort.Strings(keys) + + // Return a nice string representation of the counts by comma joining a + // slice. + var parts []string + for _, witnessType := range keys { + part := fmt.Sprintf("%d %s", counts[witnessType], witnessType) + parts = append(parts, part) + } + return strings.Join(parts, ", ") } diff --git a/sweep/txgenerator_test.go b/sweep/txgenerator_test.go index 571eb134..84484283 100644 --- a/sweep/txgenerator_test.go +++ b/sweep/txgenerator_test.go @@ -14,9 +14,10 @@ var ( input.HtlcOfferedRemoteTimeout, input.WitnessKeyHash, } - expectedWeight = int64(1459) - expectedCsv = 2 - expectedCltv = 1 + expectedWeight = int64(1459) + expectedSummary = "1 CommitmentTimeLock, 1 " + + "HtlcAcceptedSuccessSecondLevel, 1 HtlcOfferedRemoteTimeout, " + + "1 WitnessKeyHash" ) // TestWeightEstimate tests that the estimated weight and number of CSVs/CLTVs @@ -33,17 +34,14 @@ func TestWeightEstimate(t *testing.T) { )) } - _, weight, csv, cltv := getWeightEstimate(inputs) + _, weight := getWeightEstimate(inputs) if weight != expectedWeight { t.Fatalf("unexpected weight. expected %d but got %d.", expectedWeight, weight) } - if csv != expectedCsv { - t.Fatalf("unexpected csv count. expected %d but got %d.", - expectedCsv, csv) - } - if cltv != expectedCltv { - t.Fatalf("unexpected cltv count. expected %d but got %d.", - expectedCltv, cltv) + summary := inputTypeSummary(inputs) + if summary != expectedSummary { + t.Fatalf("unexpected summary. expected %s but got %s.", + expectedSummary, summary) } }