Merge pull request #2501 from cfromknecht/batch-preimage-writes

htlcswitch: batch preimage writes/consistency fix
This commit is contained in:
Olaoluwa Osuntokun 2019-02-21 17:00:00 -08:00 committed by GitHub
commit cbe0bf6a22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 710 additions and 209 deletions

@ -1548,10 +1548,7 @@ func createInitChannels(revocationWindow int) (*lnwallet.LightningChannel, *lnwa
Packager: channeldb.NewChannelPackager(shortChanID), Packager: channeldb.NewChannelPackager(shortChanID),
} }
pCache := &mockPreimageCache{ pCache := newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
aliceSigner := &mockSigner{aliceKeyPriv} aliceSigner := &mockSigner{aliceKeyPriv}
bobSigner := &mockSigner{bobKeyPriv} bobSigner := &mockSigner{bobKeyPriv}

@ -1,10 +1,10 @@
package channeldb package channeldb
import ( import (
"crypto/sha256"
"fmt" "fmt"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lntypes"
) )
var ( var (
@ -70,12 +70,42 @@ func (d *DB) NewWitnessCache() *WitnessCache {
} }
} }
// AddWitness adds a new witness of wType to the witness cache. The type of the // witnessEntry is a key-value struct that holds each key -> witness pair, used
// witness will be used to map the witness to the key that will be used to look // when inserting records into the cache.
// it up. type witnessEntry struct {
// key []byte
// TODO(roasbeef): fake closure to map instead a constructor? witness []byte
func (w *WitnessCache) AddWitness(wType WitnessType, witness []byte) error { }
// AddSha256Witnesses adds a batch of new sha256 preimages into the witness
// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the
// preimages' witness type.
func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error {
// Optimistically compute the preimages' hashes before attempting to
// start the db transaction.
entries := make([]witnessEntry, 0, len(preimages))
for i := range preimages {
hash := preimages[i].Hash()
entries = append(entries, witnessEntry{
key: hash[:],
witness: preimages[i][:],
})
}
return w.addWitnessEntries(Sha256HashWitness, entries)
}
// addWitnessEntries inserts the witnessEntry key-value pairs into the cache,
// using the appropriate witness type to segment the namespace of possible
// witness types.
func (w *WitnessCache) addWitnessEntries(wType WitnessType,
entries []witnessEntry) error {
// Exit early if there are no witnesses to add.
if len(entries) == 0 {
return nil
}
return w.db.Batch(func(tx *bbolt.Tx) error { return w.db.Batch(func(tx *bbolt.Tx) error {
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
if err != nil { if err != nil {
@ -93,23 +123,32 @@ func (w *WitnessCache) AddWitness(wType WitnessType, witness []byte) error {
return err return err
} }
// Now that we have the proper bucket for this witness, we'll map the for _, entry := range entries {
// witness type to the proper key. err = witnessTypeBucket.Put(entry.key, entry.witness)
var witnessKey []byte if err != nil {
switch wType { return err
case Sha256HashWitness: }
key := sha256.Sum256(witness)
witnessKey = key[:]
} }
return witnessTypeBucket.Put(witnessKey, witness) return nil
}) })
} }
// LookupWitness attempts to lookup a witness according to its type and also // LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If
// the witness isn't found, ErrNoWitnesses will be returned.
func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) {
witness, err := w.lookupWitness(Sha256HashWitness, hash[:])
if err != nil {
return lntypes.Preimage{}, err
}
return lntypes.MakePreimage(witness)
}
// lookupWitness attempts to lookup a witness according to its type and also
// its witness key. In the case that the witness isn't found, ErrNoWitnesses // its witness key. In the case that the witness isn't found, ErrNoWitnesses
// will be returned. // will be returned.
func (w *WitnessCache) LookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) { func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) {
var witness []byte var witness []byte
err := w.db.View(func(tx *bbolt.Tx) error { err := w.db.View(func(tx *bbolt.Tx) error {
witnessBucket := tx.Bucket(witnessBucketKey) witnessBucket := tx.Bucket(witnessBucketKey)
@ -143,8 +182,13 @@ func (w *WitnessCache) LookupWitness(wType WitnessType, witnessKey []byte) ([]by
return witness, nil return witness, nil
} }
// DeleteWitness attempts to delete a particular witness from the database. // DeleteSha256Witness attempts to delete a sha256 preimage identified by hash.
func (w *WitnessCache) DeleteWitness(wType WitnessType, witnessKey []byte) error { func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error {
return w.deleteWitness(Sha256HashWitness, hash[:])
}
// deleteWitness attempts to delete a particular witness from the database.
func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error {
return w.db.Batch(func(tx *bbolt.Tx) error { return w.db.Batch(func(tx *bbolt.Tx) error {
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
if err != nil { if err != nil {

@ -2,13 +2,14 @@ package channeldb
import ( import (
"crypto/sha256" "crypto/sha256"
"reflect"
"testing" "testing"
"github.com/lightningnetwork/lnd/lntypes"
) )
// TestWitnessCacheRetrieval tests that we're able to add and lookup new // TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new
// witnesses to the witness cache. // sha256 preimages to the witness cache.
func TestWitnessCacheRetrieval(t *testing.T) { func TestWitnessCacheSha256Retrieval(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := makeTestDB() cdb, cleanUp, err := makeTestDB()
@ -19,33 +20,41 @@ func TestWitnessCacheRetrieval(t *testing.T) {
wCache := cdb.NewWitnessCache() wCache := cdb.NewWitnessCache()
// We'll be attempting to add then lookup a d simple hash witness // We'll be attempting to add then lookup two simple sha256 preimages
// within this test. // within this test.
witness := rev[:] preimage1 := lntypes.Preimage(rev)
witnessKey := sha256.Sum256(witness) preimage2 := lntypes.Preimage(key)
// First, we'll attempt to add the witness to the database. preimages := []lntypes.Preimage{preimage1, preimage2}
err = wCache.AddWitness(Sha256HashWitness, witness) hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()}
// First, we'll attempt to add the preimages to the database.
err = wCache.AddSha256Witnesses(preimages...)
if err != nil { if err != nil {
t.Fatalf("unable to add witness: %v", err) t.Fatalf("unable to add witness: %v", err)
} }
// With the witness stored, we'll now attempt to look it up. We should // With the preimages stored, we'll now attempt to look them up.
// get back the *exact* same witness as we originally stored. for i, hash := range hashes {
dbWitness, err := wCache.LookupWitness(Sha256HashWitness, witnessKey[:]) preimage := preimages[i]
if err != nil {
t.Fatalf("unable to look up witness: %v", err)
}
if !reflect.DeepEqual(witness, dbWitness[:]) { // We should get back the *exact* same preimage as we originally
t.Fatalf("witnesses don't match: expected %x, got %x", // stored.
witness[:], dbWitness[:]) dbPreimage, err := wCache.LookupSha256Witness(hash)
if err != nil {
t.Fatalf("unable to look up witness: %v", err)
}
if preimage != dbPreimage {
t.Fatalf("witnesses don't match: expected %x, got %x",
preimage[:], dbPreimage[:])
}
} }
} }
// TestWitnessCacheDeletion tests that we're able to delete a single witness, // TestWitnessCacheSha256Deletion tests that we're able to delete a single
// and also a class of witnesses from the cache. // sha256 preimage, and also a class of witnesses from the cache.
func TestWitnessCacheDeletion(t *testing.T) { func TestWitnessCacheSha256Deletion(t *testing.T) {
t.Parallel() t.Parallel()
cdb, cleanUp, err := makeTestDB() cdb, cleanUp, err := makeTestDB()
@ -56,37 +65,39 @@ func TestWitnessCacheDeletion(t *testing.T) {
wCache := cdb.NewWitnessCache() wCache := cdb.NewWitnessCache()
// We'll start by adding two witnesses to the cache. // We'll start by adding two preimages to the cache.
witness1 := rev[:] preimage1 := lntypes.Preimage(key)
witness1Key := sha256.Sum256(witness1) hash1 := preimage1.Hash()
if err := wCache.AddWitness(Sha256HashWitness, witness1); err != nil {
preimage2 := lntypes.Preimage(rev)
hash2 := preimage2.Hash()
if err := wCache.AddSha256Witnesses(preimage1); err != nil {
t.Fatalf("unable to add witness: %v", err) t.Fatalf("unable to add witness: %v", err)
} }
witness2 := key[:] if err := wCache.AddSha256Witnesses(preimage2); err != nil {
witness2Key := sha256.Sum256(witness2)
if err := wCache.AddWitness(Sha256HashWitness, witness2); err != nil {
t.Fatalf("unable to add witness: %v", err) t.Fatalf("unable to add witness: %v", err)
} }
// We'll now delete the first witness. If we attempt to look it up, we // We'll now delete the first preimage. If we attempt to look it up, we
// should get ErrNoWitnesses. // should get ErrNoWitnesses.
err = wCache.DeleteWitness(Sha256HashWitness, witness1Key[:]) err = wCache.DeleteSha256Witness(hash1)
if err != nil { if err != nil {
t.Fatalf("unable to delete witness: %v", err) t.Fatalf("unable to delete witness: %v", err)
} }
_, err = wCache.LookupWitness(Sha256HashWitness, witness1Key[:]) _, err = wCache.LookupSha256Witness(hash1)
if err != ErrNoWitnesses { if err != ErrNoWitnesses {
t.Fatalf("expected ErrNoWitnesses instead got: %v", err) t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
} }
// Next, we'll attempt to delete the entire witness class itself. When // Next, we'll attempt to delete the entire witness class itself. When
// we try to lookup the second witness, we should again get // we try to lookup the second preimage, we should again get
// ErrNoWitnesses. // ErrNoWitnesses.
if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil { if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil {
t.Fatalf("unable to delete witness class: %v", err) t.Fatalf("unable to delete witness class: %v", err)
} }
_, err = wCache.LookupWitness(Sha256HashWitness, witness2Key[:]) _, err = wCache.LookupSha256Witness(hash2)
if err != ErrNoWitnesses { if err != ErrNoWitnesses {
t.Fatalf("expected ErrNoWitnesses instead got: %v", err) t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
} }
@ -107,8 +118,121 @@ func TestWitnessCacheUnknownWitness(t *testing.T) {
// We'll attempt to add a new, undefined witness type to the database. // We'll attempt to add a new, undefined witness type to the database.
// We should get an error. // We should get an error.
err = wCache.AddWitness(234, key[:]) err = wCache.legacyAddWitnesses(234, key[:])
if err != ErrUnknownWitnessType { if err != ErrUnknownWitnessType {
t.Fatalf("expected ErrUnknownWitnessType, got %v", err) t.Fatalf("expected ErrUnknownWitnessType, got %v", err)
} }
} }
// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves
// identically to the insertion via the generalized interface.
func TestAddSha256Witnesses(t *testing.T) {
cdb, cleanUp, err := makeTestDB()
if err != nil {
t.Fatalf("unable to make test database: %v", err)
}
defer cleanUp()
wCache := cdb.NewWitnessCache()
// We'll start by adding a witnesses to the cache using the generic
// AddWitnesses method.
witness1 := rev[:]
preimage1 := lntypes.Preimage(rev)
hash1 := preimage1.Hash()
witness2 := key[:]
preimage2 := lntypes.Preimage(key)
hash2 := preimage2.Hash()
var (
witnesses = [][]byte{witness1, witness2}
preimages = []lntypes.Preimage{preimage1, preimage2}
hashes = []lntypes.Hash{hash1, hash2}
)
err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...)
if err != nil {
t.Fatalf("unable to add witness: %v", err)
}
for i, hash := range hashes {
preimage := preimages[i]
dbPreimage, err := wCache.LookupSha256Witness(hash)
if err != nil {
t.Fatalf("unable to lookup witness: %v", err)
}
// Assert that the retrieved witness matches the original.
if dbPreimage != preimage {
t.Fatalf("retrieved witness mismatch, want: %x, "+
"got: %x", preimage, dbPreimage)
}
// We'll now delete the witness, as we'll be reinserting it
// using the specialized AddSha256Witnesses method.
err = wCache.DeleteSha256Witness(hash)
if err != nil {
t.Fatalf("unable to delete witness: %v", err)
}
}
// Now, add the same witnesses using the type-safe interface for
// lntypes.Preimages..
err = wCache.AddSha256Witnesses(preimages...)
if err != nil {
t.Fatalf("unable to add sha256 preimage: %v", err)
}
// Finally, iterate over the keys and assert that the returned witnesses
// match the original witnesses. This asserts that the specialized
// insertion method behaves identically to the generalized interface.
for i, hash := range hashes {
preimage := preimages[i]
dbPreimage, err := wCache.LookupSha256Witness(hash)
if err != nil {
t.Fatalf("unable to lookup witness: %v", err)
}
// Assert that the retrieved witness matches the original.
if dbPreimage != preimage {
t.Fatalf("retrieved witness mismatch, want: %x, "+
"got: %x", preimage, dbPreimage)
}
}
}
// legacyAddWitnesses adds a batch of new witnesses of wType to the witness
// cache. The type of the witness will be used to map each witness to the key
// that will be used to look it up. All witnesses should be of the same
// WitnessType.
//
// NOTE: Previously this method exposed a generic interface for adding
// witnesses, which has since been deprecated in favor of a strongly typed
// interface for each witness class. We keep this method around to assert the
// correctness of specialized witness adding methods.
func (w *WitnessCache) legacyAddWitnesses(wType WitnessType,
witnesses ...[]byte) error {
// Optimistically compute the witness keys before attempting to start
// the db transaction.
entries := make([]witnessEntry, 0, len(witnesses))
for _, witness := range witnesses {
// Map each witness to its key by applying the appropriate
// transformation for the given witness type.
switch wType {
case Sha256HashWitness:
key := sha256.Sum256(witness)
entries = append(entries, witnessEntry{
key: key[:],
witness: witness,
})
default:
return ErrUnknownWitnessType
}
}
return w.addWitnessEntries(wType, entries)
}

@ -11,6 +11,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -40,7 +41,7 @@ type WitnessSubscription struct {
// sent over. // sent over.
// //
// TODO(roasbeef): couple with WitnessType? // TODO(roasbeef): couple with WitnessType?
WitnessUpdates <-chan []byte WitnessUpdates <-chan lntypes.Preimage
// CancelSubscription is a function closure that should be used by a // CancelSubscription is a function closure that should be used by a
// client to cancel the subscription once they are no longer interested // client to cancel the subscription once they are no longer interested
@ -62,10 +63,12 @@ type WitnessBeacon interface {
// LookupPreImage attempts to lookup a preimage in the global cache. // LookupPreImage attempts to lookup a preimage in the global cache.
// True is returned for the second argument if the preimage is found. // True is returned for the second argument if the preimage is found.
LookupPreimage(payhash []byte) ([]byte, bool) LookupPreimage(payhash lntypes.Hash) (lntypes.Preimage, bool)
// AddPreImage adds a newly discovered preimage to the global cache. // AddPreimages adds a batch of newly discovered preimages to the global
AddPreimage(pre []byte) error // cache, and also signals any subscribers of the newly discovered
// witness.
AddPreimages(preimages ...lntypes.Preimage) error
} }
// ChannelArbitratorConfig contains all the functionality that the // ChannelArbitratorConfig contains all the functionality that the
@ -1127,7 +1130,7 @@ func (c *ChannelArbitrator) checkChainActions(height uint32,
// know the pre-image and it's close to timing out. We need to // know the pre-image and it's close to timing out. We need to
// ensure that we claim the funds that our rightfully ours // ensure that we claim the funds that our rightfully ours
// on-chain. // on-chain.
if _, ok := c.cfg.PreimageDB.LookupPreimage(htlc.RHash[:]); !ok { if _, ok := c.cfg.PreimageDB.LookupPreimage(htlc.RHash); !ok {
continue continue
} }
haveChainActions = haveChainActions || c.shouldGoOnChain( haveChainActions = haveChainActions || c.shouldGoOnChain(
@ -1204,13 +1207,12 @@ func (c *ChannelArbitrator) checkChainActions(height uint32,
// either learn of it eventually from the outgoing HTLC, or the sender // either learn of it eventually from the outgoing HTLC, or the sender
// will timeout the HTLC. // will timeout the HTLC.
for _, htlc := range c.activeHTLCs.incomingHTLCs { for _, htlc := range c.activeHTLCs.incomingHTLCs {
payHash := htlc.RHash
// If we have the pre-image, then we should go on-chain to // If we have the pre-image, then we should go on-chain to
// redeem the HTLC immediately. // redeem the HTLC immediately.
if _, ok := c.cfg.PreimageDB.LookupPreimage(payHash[:]); ok { if _, ok := c.cfg.PreimageDB.LookupPreimage(htlc.RHash); ok {
log.Tracef("ChannelArbitrator(%v): preimage for "+ log.Tracef("ChannelArbitrator(%v): preimage for "+
"htlc=%x is known!", c.cfg.ChanPoint, payHash[:]) "htlc=%x is known!", c.cfg.ChanPoint,
htlc.RHash[:])
actionMap[HtlcClaimAction] = append( actionMap[HtlcClaimAction] = append(
actionMap[HtlcClaimAction], htlc, actionMap[HtlcClaimAction], htlc,
@ -1220,7 +1222,7 @@ func (c *ChannelArbitrator) checkChainActions(height uint32,
log.Tracef("ChannelArbitrator(%v): watching chain to decide "+ log.Tracef("ChannelArbitrator(%v): watching chain to decide "+
"action for incoming htlc=%x", c.cfg.ChanPoint, "action for incoming htlc=%x", c.cfg.ChanPoint,
payHash[:]) htlc.RHash[:])
// Otherwise, we don't yet have the pre-image, but should watch // Otherwise, we don't yet have the pre-image, but should watch
// on-chain to see if either: the remote party times out the // on-chain to see if either: the remote party times out the

@ -2,12 +2,12 @@ package contractcourt
import ( import (
"bytes" "bytes"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lntypes"
) )
// htlcIncomingContestResolver is a ContractResolver that's able to resolve an // htlcIncomingContestResolver is a ContractResolver that's able to resolve an
@ -74,11 +74,11 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
// resolver with the preimage we learn of. This should be called once // resolver with the preimage we learn of. This should be called once
// the preimage is revealed so the inner resolver can properly complete // the preimage is revealed so the inner resolver can properly complete
// its duties. // its duties.
applyPreimage := func(preimage []byte) { applyPreimage := func(preimage lntypes.Preimage) {
copy(h.htlcResolution.Preimage[:], preimage) h.htlcResolution.Preimage = preimage
log.Infof("%T(%v): extracted preimage=%x from beacon!", h, log.Infof("%T(%v): extracted preimage=%v from beacon!", h,
h.htlcResolution.ClaimOutpoint, preimage[:]) h.htlcResolution.ClaimOutpoint, preimage)
// If this our commitment transaction, then we'll need to // If this our commitment transaction, then we'll need to
// populate the witness for the second-level HTLC transaction. // populate the witness for the second-level HTLC transaction.
@ -93,8 +93,6 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
// preimage. // preimage.
h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:] h.htlcResolution.SignedSuccessTx.TxIn[0].Witness[3] = preimage[:]
} }
copy(h.htlcResolution.Preimage[:], preimage[:])
} }
// If the HTLC hasn't expired yet, then we may still be able to claim // If the HTLC hasn't expired yet, then we may still be able to claim
@ -116,12 +114,12 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
// With the epochs and preimage subscriptions initialized, we'll query // With the epochs and preimage subscriptions initialized, we'll query
// to see if we already know the preimage. // to see if we already know the preimage.
preimage, ok := h.PreimageDB.LookupPreimage(h.payHash[:]) preimage, ok := h.PreimageDB.LookupPreimage(h.payHash)
if ok { if ok {
// If we do, then this means we can claim the HTLC! However, // If we do, then this means we can claim the HTLC! However,
// we don't know how to ourselves, so we'll return our inner // we don't know how to ourselves, so we'll return our inner
// resolver which has the knowledge to do so. // resolver which has the knowledge to do so.
applyPreimage(preimage[:]) applyPreimage(preimage)
return &h.htlcSuccessResolver, nil return &h.htlcSuccessResolver, nil
} }
@ -131,8 +129,8 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
case preimage := <-preimageSubscription.WitnessUpdates: case preimage := <-preimageSubscription.WitnessUpdates:
// If this isn't our preimage, then we'll continue // If this isn't our preimage, then we'll continue
// onwards. // onwards.
newHash := sha256.Sum256(preimage) hash := preimage.Hash()
preimageMatches := bytes.Equal(newHash[:], h.payHash[:]) preimageMatches := bytes.Equal(hash[:], h.payHash[:])
if !preimageMatches { if !preimageMatches {
continue continue
} }

@ -2,13 +2,15 @@ package contractcourt
import ( import (
"fmt" "fmt"
"github.com/lightningnetwork/lnd/input"
"io" "io"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
) )
// htlcOutgoingContestResolver is a ContractResolver that's able to resolve an // htlcOutgoingContestResolver is a ContractResolver that's able to resolve an
@ -58,38 +60,47 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) {
// If this is the remote party's commitment, then we'll be // If this is the remote party's commitment, then we'll be
// looking for them to spend using the second-level success // looking for them to spend using the second-level success
// transaction. // transaction.
var preimage [32]byte var preimageBytes []byte
if h.htlcResolution.SignedTimeoutTx == nil { if h.htlcResolution.SignedTimeoutTx == nil {
// The witness stack when the remote party sweeps the // The witness stack when the remote party sweeps the
// output to them looks like: // output to them looks like:
// //
// * <sender sig> <recvr sig> <preimage> <witness script> // * <sender sig> <recvr sig> <preimage> <witness script>
copy(preimage[:], spendingInput.Witness[3]) preimageBytes = spendingInput.Witness[3]
} else { } else {
// Otherwise, they'll be spending directly from our // Otherwise, they'll be spending directly from our
// commitment output. In which case the witness stack // commitment output. In which case the witness stack
// looks like: // looks like:
// //
// * <sig> <preimage> <witness script> // * <sig> <preimage> <witness script>
copy(preimage[:], spendingInput.Witness[1]) preimageBytes = spendingInput.Witness[1]
} }
log.Infof("%T(%v): extracting preimage=%x from on-chain "+ preimage, err := lntypes.MakePreimage(preimageBytes)
"spend!", h, h.htlcResolution.ClaimOutpoint, preimage[:]) if err != nil {
return nil, err
}
log.Infof("%T(%v): extracting preimage=%v from on-chain "+
"spend!", h, h.htlcResolution.ClaimOutpoint,
preimage)
// With the preimage obtained, we can now add it to the global // With the preimage obtained, we can now add it to the global
// cache. // cache.
if err := h.PreimageDB.AddPreimage(preimage[:]); err != nil { if err := h.PreimageDB.AddPreimages(preimage); err != nil {
log.Errorf("%T(%v): unable to add witness to cache", log.Errorf("%T(%v): unable to add witness to cache",
h, h.htlcResolution.ClaimOutpoint) h, h.htlcResolution.ClaimOutpoint)
} }
var pre [32]byte
copy(pre[:], preimage[:])
// Finally, we'll send the clean up message, mark ourselves as // Finally, we'll send the clean up message, mark ourselves as
// resolved, then exit. // resolved, then exit.
if err := h.DeliverResolutionMsg(ResolutionMsg{ if err := h.DeliverResolutionMsg(ResolutionMsg{
SourceChan: h.ShortChanID, SourceChan: h.ShortChanID,
HtlcIndex: h.htlcIndex, HtlcIndex: h.htlcIndex,
PreImage: &preimage, PreImage: &pre,
}); err != nil { }); err != nil {
return nil, err return nil, err
} }

@ -3,15 +3,16 @@ package contractcourt
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"github.com/lightningnetwork/lnd/input"
"io" "io"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/sweep"
) )
@ -41,7 +42,7 @@ type htlcSuccessResolver struct {
broadcastHeight uint32 broadcastHeight uint32
// payHash is the payment hash of the original HTLC extended to us. // payHash is the payment hash of the original HTLC extended to us.
payHash [32]byte payHash lntypes.Hash
// sweepTx will be non-nil if we've already crafted a transaction to // sweepTx will be non-nil if we've already crafted a transaction to
// sweep a direct HTLC output. This is only a concern if we're sweeping // sweep a direct HTLC output. This is only a concern if we're sweeping

@ -333,6 +333,12 @@ type channelLink struct {
// commitment fee every time it fires. // commitment fee every time it fires.
updateFeeTimer *time.Timer updateFeeTimer *time.Timer
// uncommittedPreimages stores a list of all preimages that have been
// learned since receiving the last CommitSig from the remote peer. The
// batch will be flushed just before accepting the subsequent CommitSig
// or on shutdown to avoid doing a write for each preimage received.
uncommittedPreimages []lntypes.Preimage
sync.RWMutex sync.RWMutex
wg sync.WaitGroup wg sync.WaitGroup
@ -449,6 +455,18 @@ func (l *channelLink) Stop() {
close(l.quit) close(l.quit)
l.wg.Wait() l.wg.Wait()
// As a final precaution, we will attempt to flush any uncommitted
// preimages to the preimage cache. The preimages should be re-delivered
// after channel reestablishment, however this adds an extra layer of
// protection in case the peer never returns. Without this, we will be
// unable to settle any contracts depending on the preimages even though
// we had learned them at some point.
err := l.cfg.PreimageCache.AddPreimages(l.uncommittedPreimages...)
if err != nil {
log.Errorf("Unable to add preimages=%v to cache: %v",
l.uncommittedPreimages, err)
}
} }
// WaitForShutdown blocks until the link finishes shutting down, which includes // WaitForShutdown blocks until the link finishes shutting down, which includes
@ -1412,17 +1430,11 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
// TODO(roasbeef): pipeline to switch // TODO(roasbeef): pipeline to switch
// As we've learned of a new preimage for the first time, we'll // Add the newly discovered preimage to our growing list of
// add it to our preimage cache. By doing this, we ensure // uncommitted preimage. These will be written to the witness
// any contested contracts watched by any on-chain arbitrators // cache just before accepting the next commitment signature
// can now sweep this HTLC on-chain. // from the remote peer.
go func() { l.uncommittedPreimages = append(l.uncommittedPreimages, pre)
err := l.cfg.PreimageCache.AddPreimage(pre[:])
if err != nil {
l.errorf("unable to add preimage=%x to "+
"cache", pre[:])
}
}()
case *lnwire.UpdateFailMalformedHTLC: case *lnwire.UpdateFailMalformedHTLC:
// Convert the failure type encoded within the HTLC fail // Convert the failure type encoded within the HTLC fail
@ -1475,10 +1487,39 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) {
} }
case *lnwire.CommitSig: case *lnwire.CommitSig:
// Since we may have learned new preimages for the first time,
// we'll add them to our preimage cache. By doing this, we
// ensure any contested contracts watched by any on-chain
// arbitrators can now sweep this HTLC on-chain. We delay
// committing the preimages until just before accepting the new
// remote commitment, as afterwards the peer won't resend the
// Settle messages on the next channel reestablishment. Doing so
// allows us to more effectively batch this operation, instead
// of doing a single write per preimage.
err := l.cfg.PreimageCache.AddPreimages(
l.uncommittedPreimages...,
)
if err != nil {
l.fail(
LinkFailureError{code: ErrInternalError},
"unable to add preimages=%v to cache: %v",
l.uncommittedPreimages, err,
)
return
}
// Instead of truncating the slice to conserve memory
// allocations, we simply set the uncommitted preimage slice to
// nil so that a new one will be initialized if any more
// witnesses are discovered. We do this maximum size of the
// slice can occupy 15KB, and want to ensure we release that
// memory back to the runtime.
l.uncommittedPreimages = nil
// We just received a new updates to our local commitment // We just received a new updates to our local commitment
// chain, validate this new commitment, closing the link if // chain, validate this new commitment, closing the link if
// invalid. // invalid.
err := l.channel.ReceiveNewCommitment(msg.CommitSig, msg.HtlcSigs) err = l.channel.ReceiveNewCommitment(msg.CommitSig, msg.HtlcSigs)
if err != nil { if err != nil {
// If we were unable to reconstruct their proposed // If we were unable to reconstruct their proposed
// commitment, then we'll examine the type of error. If // commitment, then we'll examine the type of error. If

@ -3,6 +3,7 @@ package htlcswitch
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
@ -28,6 +29,7 @@ import (
"github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnpeer"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
@ -1532,10 +1534,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta) invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
) )
pCache := &mockPreimageCache{ pCache := newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db
aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb) aliceSwitch, err := initSwitchWithDB(testStartingHeight, aliceDb)
@ -4042,10 +4041,7 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta) invoiceRegistry = newMockRegistry(globalPolicy.TimeLockDelta)
pCache = &mockPreimageCache{ pCache = newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
) )
aliceDb := aliceChannel.State().Db aliceDb := aliceChannel.State().Db
@ -4120,6 +4116,29 @@ func restartLink(aliceChannel *lnwallet.LightningChannel, aliceSwitch *Switch,
// gnerateHtlc generates a simple payment from Bob to Alice. // gnerateHtlc generates a simple payment from Bob to Alice.
func generateHtlc(t *testing.T, coreLink *channelLink, func generateHtlc(t *testing.T, coreLink *channelLink,
bobChannel *lnwallet.LightningChannel, id uint64) *lnwire.UpdateAddHTLC { bobChannel *lnwallet.LightningChannel, id uint64) *lnwire.UpdateAddHTLC {
t.Helper()
htlc, invoice := generateHtlcAndInvoice(t, id)
// We must add the invoice to the registry, such that Alice
// expects this payment.
err := coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(
*invoice)
if err != nil {
t.Fatalf("unable to add invoice to registry: %v", err)
}
return htlc
}
// generateHtlcAndInvoice generates an invoice and a single hop htlc to send to
// the receiver.
func generateHtlcAndInvoice(t *testing.T,
id uint64) (*lnwire.UpdateAddHTLC, *channeldb.Invoice) {
t.Helper()
htlcAmt := lnwire.NewMSatFromSatoshis(10000) htlcAmt := lnwire.NewMSatFromSatoshis(10000)
hops := []ForwardingInfo{ hops := []ForwardingInfo{
{ {
@ -4130,27 +4149,28 @@ func generateHtlc(t *testing.T, coreLink *channelLink,
}, },
} }
blob, err := generateRoute(hops...) blob, err := generateRoute(hops...)
if err != nil {
t.Fatalf("unable to generate route: %v", err)
}
invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 144, invoice, htlc, err := generatePayment(htlcAmt, htlcAmt, 144,
blob) blob)
if err != nil { if err != nil {
t.Fatalf("unable to create payment: %v", err) t.Fatalf("unable to create payment: %v", err)
} }
// We must add the invoice to the registry, such that Alice
// expects this payment.
err = coreLink.cfg.Registry.(*mockInvoiceRegistry).AddInvoice(
*invoice)
if err != nil {
t.Fatalf("unable to add invoice to registry: %v", err)
}
htlc.ID = id htlc.ID = id
return htlc
return htlc, invoice
} }
// sendHtlcBobToAlice sends an HTLC from Bob to Alice, that pays to a preimage // sendHtlcBobToAlice sends an HTLC from Bob to Alice, that pays to a preimage
// already in Alice's registry. // already in Alice's registry.
func sendHtlcBobToAlice(t *testing.T, aliceLink ChannelLink, func sendHtlcBobToAlice(t *testing.T, aliceLink ChannelLink,
bobChannel *lnwallet.LightningChannel, htlc *lnwire.UpdateAddHTLC) { bobChannel *lnwallet.LightningChannel, htlc *lnwire.UpdateAddHTLC) {
t.Helper()
_, err := bobChannel.AddHTLC(htlc, nil) _, err := bobChannel.AddHTLC(htlc, nil)
if err != nil { if err != nil {
t.Fatalf("bob failed adding htlc: %v", err) t.Fatalf("bob failed adding htlc: %v", err)
@ -4159,10 +4179,70 @@ func sendHtlcBobToAlice(t *testing.T, aliceLink ChannelLink,
aliceLink.HandleChannelUpdate(htlc) aliceLink.HandleChannelUpdate(htlc)
} }
// sendHtlcAliceToBob sends an HTLC from Alice to Bob, by first committing the
// HTLC in the circuit map, then delivering the outgoing packet to Alice's link.
// The HTLC will be sent to Bob via Alice's message stream.
func sendHtlcAliceToBob(t *testing.T, aliceLink ChannelLink, htlcID int,
htlc *lnwire.UpdateAddHTLC) {
t.Helper()
circuitMap := aliceLink.(*channelLink).cfg.Switch.circuits
fwdActions, err := circuitMap.CommitCircuits(
&PaymentCircuit{
Incoming: CircuitKey{
HtlcID: uint64(htlcID),
},
PaymentHash: htlc.PaymentHash,
},
)
if err != nil {
t.Fatalf("unable to commit circuit: %v", err)
}
if len(fwdActions.Adds) != 1 {
t.Fatalf("expected 1 adds, found %d", len(fwdActions.Adds))
}
aliceLink.HandleSwitchPacket(&htlcPacket{
incomingHTLCID: uint64(htlcID),
htlc: htlc,
})
}
// receiveHtlcAliceToBob pulls the next message from Alice's message stream,
// asserts that it is an UpdateAddHTLC, then applies it to Bob's state machine.
func receiveHtlcAliceToBob(t *testing.T, aliceMsgs <-chan lnwire.Message,
bobChannel *lnwallet.LightningChannel) {
t.Helper()
var msg lnwire.Message
select {
case msg = <-aliceMsgs:
case <-time.After(15 * time.Second):
t.Fatalf("did not received htlc from alice")
}
htlcAdd, ok := msg.(*lnwire.UpdateAddHTLC)
if !ok {
t.Fatalf("expected UpdateAddHTLC, got %T", msg)
}
_, err := bobChannel.ReceiveHTLC(htlcAdd)
if err != nil {
t.Fatalf("bob failed receiving htlc: %v", err)
}
}
// sendCommitSigBobToAlice makes Bob sign a new commitment and send it to // sendCommitSigBobToAlice makes Bob sign a new commitment and send it to
// Alice, asserting that it signs expHtlcs number of HTLCs. // Alice, asserting that it signs expHtlcs number of HTLCs.
func sendCommitSigBobToAlice(t *testing.T, aliceLink ChannelLink, func sendCommitSigBobToAlice(t *testing.T, aliceLink ChannelLink,
bobChannel *lnwallet.LightningChannel, expHtlcs int) { bobChannel *lnwallet.LightningChannel, expHtlcs int) {
t.Helper()
sig, htlcSigs, err := bobChannel.SignNextCommitment() sig, htlcSigs, err := bobChannel.SignNextCommitment()
if err != nil { if err != nil {
t.Fatalf("error signing commitment: %v", err) t.Fatalf("error signing commitment: %v", err)
@ -4186,6 +4266,9 @@ func sendCommitSigBobToAlice(t *testing.T, aliceLink ChannelLink,
func receiveRevAndAckAliceToBob(t *testing.T, aliceMsgs chan lnwire.Message, func receiveRevAndAckAliceToBob(t *testing.T, aliceMsgs chan lnwire.Message,
aliceLink ChannelLink, aliceLink ChannelLink,
bobChannel *lnwallet.LightningChannel) { bobChannel *lnwallet.LightningChannel) {
t.Helper()
var msg lnwire.Message var msg lnwire.Message
select { select {
case msg = <-aliceMsgs: case msg = <-aliceMsgs:
@ -4239,6 +4322,9 @@ func receiveCommitSigAliceToBob(t *testing.T, aliceMsgs chan lnwire.Message,
// the RevokeAndAck to Alice. // the RevokeAndAck to Alice.
func sendRevAndAckBobToAlice(t *testing.T, aliceLink ChannelLink, func sendRevAndAckBobToAlice(t *testing.T, aliceLink ChannelLink,
bobChannel *lnwallet.LightningChannel) { bobChannel *lnwallet.LightningChannel) {
t.Helper()
rev, _, err := bobChannel.RevokeCurrentCommitment() rev, _, err := bobChannel.RevokeCurrentCommitment()
if err != nil { if err != nil {
t.Fatalf("unable to revoke commitment: %v", err) t.Fatalf("unable to revoke commitment: %v", err)
@ -4273,6 +4359,28 @@ func receiveSettleAliceToBob(t *testing.T, aliceMsgs chan lnwire.Message,
} }
} }
// sendSettleBobToAlice settles an HTLC on Bob's state machine, then sends an
// UpdateFulfillHTLC message to Alice's upstream inbox.
func sendSettleBobToAlice(t *testing.T, aliceLink ChannelLink,
bobChannel *lnwallet.LightningChannel, htlcID uint64,
preimage lntypes.Preimage) {
t.Helper()
err := bobChannel.SettleHTLC(preimage, htlcID, nil, nil, nil)
if err != nil {
t.Fatalf("alice failed settling htlc id=%d hash=%x",
htlcID, sha256.Sum256(preimage[:]))
}
settle := &lnwire.UpdateFulfillHTLC{
ID: htlcID,
PaymentPreimage: preimage,
}
aliceLink.HandleChannelUpdate(settle)
}
// receiveSettleAliceToBob waits for Alice to send a HTLC settle message to // receiveSettleAliceToBob waits for Alice to send a HTLC settle message to
// Bob, then hands this to Bob. // Bob, then hands this to Bob.
func receiveFailAliceToBob(t *testing.T, aliceMsgs chan lnwire.Message, func receiveFailAliceToBob(t *testing.T, aliceMsgs chan lnwire.Message,
@ -4389,6 +4497,26 @@ func TestChannelLinkNoMoreUpdates(t *testing.T) {
} }
} }
// checkHasPreimages inspects Alice's preimage cache, and asserts whether the
// preimages for the provided HTLCs are known and unknown, and that all of them
// match the expected status of expOk.
func checkHasPreimages(t *testing.T, coreLink *channelLink,
htlcs []*lnwire.UpdateAddHTLC, expOk bool) {
t.Helper()
for i := range htlcs {
_, ok := coreLink.cfg.PreimageCache.LookupPreimage(
htlcs[i].PaymentHash,
)
if ok != expOk {
t.Fatalf("expected to find witness: %v, "+
"got %v for hash=%x", expOk, ok,
htlcs[i].PaymentHash)
}
}
}
// TestChannelLinkWaitForRevocation tests that we will keep accepting updates // TestChannelLinkWaitForRevocation tests that we will keep accepting updates
// to our commitment transaction, even when we are waiting for a revocation // to our commitment transaction, even when we are waiting for a revocation
// from the remote node. // from the remote node.
@ -4500,6 +4628,135 @@ func TestChannelLinkWaitForRevocation(t *testing.T) {
} }
} }
// TestChannelLinkBatchPreimageWrite asserts that a link will batch preimage
// writes when just as it receives a CommitSig to lock in any Settles, and also
// if the link is aware of any uncommitted preimages if the link is stopped,
// i.e. due to a disconnection or shutdown.
func TestChannelLinkBatchPreimageWrite(t *testing.T) {
t.Parallel()
tests := []struct {
name string
disconnect bool
}{
{
name: "flush on commit sig",
disconnect: false,
},
{
name: "flush on disconnect",
disconnect: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testChannelLinkBatchPreimageWrite(t, test.disconnect)
})
}
}
func testChannelLinkBatchPreimageWrite(t *testing.T, disconnect bool) {
const chanAmt = btcutil.SatoshiPerBitcoin * 5
const chanReserve = btcutil.SatoshiPerBitcoin * 1
aliceLink, bobChannel, batchTicker, startUp, cleanUp, _, err :=
newSingleLinkTestHarness(chanAmt, chanReserve)
if err != nil {
t.Fatalf("unable to create link: %v", err)
}
defer cleanUp()
if err := startUp(); err != nil {
t.Fatalf("unable to start test harness: %v", err)
}
var (
coreLink = aliceLink.(*channelLink)
aliceMsgs = coreLink.cfg.Peer.(*mockPeer).sentMsgs
)
// We will send 10 HTLCs in total, from Bob to Alice.
numHtlcs := 10
var htlcs []*lnwire.UpdateAddHTLC
var invoices []*channeldb.Invoice
for i := 0; i < numHtlcs; i++ {
htlc, invoice := generateHtlcAndInvoice(t, uint64(i))
htlcs = append(htlcs, htlc)
invoices = append(invoices, invoice)
}
// First, send a batch of Adds from Alice to Bob.
for i, htlc := range htlcs {
sendHtlcAliceToBob(t, aliceLink, i, htlc)
receiveHtlcAliceToBob(t, aliceMsgs, bobChannel)
}
// Assert that no preimages exist for these htlcs in Alice's cache.
checkHasPreimages(t, coreLink, htlcs, false)
// Force alice's link to sign a commitment covering the htlcs sent thus
// far.
select {
case batchTicker <- time.Now():
case <-time.After(15 * time.Second):
t.Fatalf("could not force commit sig")
}
// Do a commitment dance to lock in the Adds, we expect numHtlcs htlcs
// to be on each party's commitment transactions.
receiveCommitSigAliceToBob(
t, aliceMsgs, aliceLink, bobChannel, numHtlcs,
)
sendRevAndAckBobToAlice(t, aliceLink, bobChannel)
sendCommitSigBobToAlice(t, aliceLink, bobChannel, numHtlcs)
receiveRevAndAckAliceToBob(t, aliceMsgs, aliceLink, bobChannel)
// Check again that no preimages exist for these htlcs in Alice's cache.
checkHasPreimages(t, coreLink, htlcs, false)
// Now, have Bob settle the HTLCs back to Alice using the preimages in
// the invoice corresponding to each of the HTLCs.
for i, invoice := range invoices {
sendSettleBobToAlice(
t, aliceLink, bobChannel, uint64(i),
invoice.Terms.PaymentPreimage,
)
}
// Assert that Alice has not yet written the preimages, even though she
// has received them in the UpdateFulfillHTLC messages.
checkHasPreimages(t, coreLink, htlcs, false)
// If this is the disconnect run, we will having Bob send Alice his
// CommitSig, and simply stop Alice's link. As she exits, we should
// detect that she has uncommitted preimages and write them to disk.
if disconnect {
aliceLink.Stop()
checkHasPreimages(t, coreLink, htlcs, true)
return
}
// Otherwise, we are testing that Alice commits the preimages after
// receiving a CommitSig from Bob. Bob's commitment should now have 0
// HTLCs.
sendCommitSigBobToAlice(t, aliceLink, bobChannel, 0)
// Since Alice will process the CommitSig asynchronously, we wait until
// she replies with her RevokeAndAck to ensure the tests reliably
// inspect her cache after advancing her state.
select {
// Received Alice's RevokeAndAck, assert that she has written all of the
// uncommitted preimages learned in this commitment.
case <-aliceMsgs:
checkHasPreimages(t, coreLink, htlcs, true)
// Alice didn't send her RevokeAndAck, something is wrong.
case <-time.After(15 * time.Second):
t.Fatalf("alice did not send her revocation")
}
}
// TestChannelLinkCleanupSpuriousResponses tests that we properly cleanup // TestChannelLinkCleanupSpuriousResponses tests that we properly cleanup
// references in the event that internal retransmission continues as a result of // references in the event that internal retransmission continues as a result of
// not properly cleaning up Add/SettleFailRefs. // not properly cleaning up Add/SettleFailRefs.

@ -32,25 +32,32 @@ import (
type mockPreimageCache struct { type mockPreimageCache struct {
sync.Mutex sync.Mutex
preimageMap map[[32]byte][]byte preimageMap map[lntypes.Hash]lntypes.Preimage
} }
func (m *mockPreimageCache) LookupPreimage(hash []byte) ([]byte, bool) { func newMockPreimageCache() *mockPreimageCache {
return &mockPreimageCache{
preimageMap: make(map[lntypes.Hash]lntypes.Preimage),
}
}
func (m *mockPreimageCache) LookupPreimage(
hash lntypes.Hash) (lntypes.Preimage, bool) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
var h [32]byte p, ok := m.preimageMap[hash]
copy(h[:], hash)
p, ok := m.preimageMap[h]
return p, ok return p, ok
} }
func (m *mockPreimageCache) AddPreimage(preimage []byte) error { func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
m.preimageMap[sha256.Sum256(preimage[:])] = preimage for _, preimage := range preimages {
m.preimageMap[preimage.Hash()] = preimage
}
return nil return nil
} }

@ -367,10 +367,7 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
aliceSigner := &mockSigner{aliceKeyPriv} aliceSigner := &mockSigner{aliceKeyPriv}
bobSigner := &mockSigner{bobKeyPriv} bobSigner := &mockSigner{bobKeyPriv}
pCache := &mockPreimageCache{ pCache := newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
alicePool := lnwallet.NewSigPool(runtime.NumCPU(), aliceSigner) alicePool := lnwallet.NewSigPool(runtime.NumCPU(), aliceSigner)
channelAlice, err := lnwallet.NewLightningChannel( channelAlice, err := lnwallet.NewLightningChannel(
@ -982,10 +979,7 @@ type hopNetwork struct {
func newHopNetwork() *hopNetwork { func newHopNetwork() *hopNetwork {
defaultDelta := uint32(6) defaultDelta := uint32(6)
pCache := &mockPreimageCache{ pCache := newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
globalPolicy := ForwardingPolicy{ globalPolicy := ForwardingPolicy{
MinHTLC: lnwire.NewMSatFromSatoshis(5), MinHTLC: lnwire.NewMSatFromSatoshis(5),

@ -23,7 +23,7 @@ var (
// All nodes initialized with the flag active will immediately settle // All nodes initialized with the flag active will immediately settle
// any incoming HTLC whose rHash corresponds with the debug // any incoming HTLC whose rHash corresponds with the debug
// preimage. // preimage.
DebugPre, _ = lntypes.NewPreimage(bytes.Repeat([]byte{1}, 32)) DebugPre, _ = lntypes.MakePreimage(bytes.Repeat([]byte{1}, 32))
// DebugHash is the hash of the default preimage. // DebugHash is the hash of the default preimage.
DebugHash = DebugPre.Hash() DebugHash = DebugPre.Hash()

@ -9,8 +9,8 @@ import (
// PreimageSize of array used to store preimagees. // PreimageSize of array used to store preimagees.
const PreimageSize = 32 const PreimageSize = 32
// Preimage is used in several of the lightning messages and common structures. It // Preimage is used in several of the lightning messages and common structures.
// represents a payment preimage. // It represents a payment preimage.
type Preimage [PreimageSize]byte type Preimage [PreimageSize]byte
// String returns the Preimage as a hexadecimal string. // String returns the Preimage as a hexadecimal string.
@ -18,35 +18,35 @@ func (p Preimage) String() string {
return hex.EncodeToString(p[:]) return hex.EncodeToString(p[:])
} }
// NewPreimage returns a new Preimage from a byte slice. An error is returned if // MakePreimage returns a new Preimage from a bytes slice. An error is returned
// the number of bytes passed in is not PreimageSize. // if the number of bytes passed in is not PreimageSize.
func NewPreimage(newPreimage []byte) (*Preimage, error) { func MakePreimage(newPreimage []byte) (Preimage, error) {
nhlen := len(newPreimage) nhlen := len(newPreimage)
if nhlen != PreimageSize { if nhlen != PreimageSize {
return nil, fmt.Errorf("invalid preimage length of %v, want %v", return Preimage{}, fmt.Errorf("invalid preimage length of %v, "+
nhlen, PreimageSize) "want %v", nhlen, PreimageSize)
} }
var preimage Preimage var preimage Preimage
copy(preimage[:], newPreimage) copy(preimage[:], newPreimage)
return &preimage, nil return preimage, nil
} }
// NewPreimageFromStr creates a Preimage from a hex preimage string. // MakePreimageFromStr creates a Preimage from a hex preimage string.
func NewPreimageFromStr(newPreimage string) (*Preimage, error) { func MakePreimageFromStr(newPreimage string) (Preimage, error) {
// Return error if preimage string is of incorrect length. // Return error if preimage string is of incorrect length.
if len(newPreimage) != PreimageSize*2 { if len(newPreimage) != PreimageSize*2 {
return nil, fmt.Errorf("invalid preimage string length of %v, "+ return Preimage{}, fmt.Errorf("invalid preimage string length "+
"want %v", len(newPreimage), PreimageSize*2) "of %v, want %v", len(newPreimage), PreimageSize*2)
} }
preimage, err := hex.DecodeString(newPreimage) preimage, err := hex.DecodeString(newPreimage)
if err != nil { if err != nil {
return nil, err return Preimage{}, err
} }
return NewPreimage(preimage) return MakePreimage(preimage)
} }
// Hash returns the sha256 hash of the preimage. // Hash returns the sha256 hash of the preimage.

@ -5577,12 +5577,12 @@ func extractHtlcResolutions(feePerKw SatPerKWeight, ourCommit bool,
// We'll now query the preimage cache for the preimage // We'll now query the preimage cache for the preimage
// for this HTLC. If it's present then we can fully // for this HTLC. If it's present then we can fully
// populate this resolution. // populate this resolution.
preimage, _ := pCache.LookupPreimage(htlc.RHash[:]) preimage, _ := pCache.LookupPreimage(htlc.RHash)
// Otherwise, we'll create an incoming HTLC resolution // Otherwise, we'll create an incoming HTLC resolution
// as we can satisfy the contract. // as we can satisfy the contract.
var pre [32]byte var pre [32]byte
copy(pre[:], preimage) copy(pre[:], preimage[:])
ihr, err := newIncomingHtlcResolution( ihr, err := newIncomingHtlcResolution(
signer, localChanCfg, commitHash, &htlc, keyRing, signer, localChanCfg, commitHash, &htlc, keyRing,
feePerKw, dustLimit, uint32(csvDelay), ourCommit, feePerKw, dustLimit, uint32(csvDelay), ourCommit,

@ -19,6 +19,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -584,7 +585,7 @@ func TestForceClose(t *testing.T) {
// Before we force close Alice's channel, we'll add the pre-image of // Before we force close Alice's channel, we'll add the pre-image of
// Bob's HTLC to her preimage cache. // Bob's HTLC to her preimage cache.
aliceChannel.pCache.AddPreimage(preimageBob[:]) aliceChannel.pCache.AddPreimages(lntypes.Preimage(preimageBob))
// With the cache populated, we'll now attempt the force close // With the cache populated, we'll now attempt the force close
// initiated by Alice. // initiated by Alice.
@ -4953,7 +4954,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) {
// Now that Bob has force closed, we'll modify Alice's pre image cache // Now that Bob has force closed, we'll modify Alice's pre image cache
// such that she now gains the ability to also settle the incoming HTLC // such that she now gains the ability to also settle the incoming HTLC
// from Bob. // from Bob.
aliceChannel.pCache.AddPreimage(preimageBob[:]) aliceChannel.pCache.AddPreimages(lntypes.Preimage(preimageBob))
// We'll then use Bob's transaction to trigger a spend notification for // We'll then use Bob's transaction to trigger a spend notification for
// Alice. // Alice.

@ -9,6 +9,7 @@ import (
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lntypes"
) )
// AddressType is an enum-like type which denotes the possible address types // AddressType is an enum-like type which denotes the possible address types
@ -272,11 +273,12 @@ type PreimageCache interface {
// LookupPreimage attempts to look up a preimage according to its hash. // LookupPreimage attempts to look up a preimage according to its hash.
// If found, the preimage is returned along with true for the second // If found, the preimage is returned along with true for the second
// argument. Otherwise, it'll return false. // argument. Otherwise, it'll return false.
LookupPreimage(hash []byte) ([]byte, bool) LookupPreimage(hash lntypes.Hash) (lntypes.Preimage, bool)
// AddPreimage attempts to add a new preimage to the global cache. If // AddPreimages adds a batch of newly discovered preimages to the global
// successful a nil error will be returned. // cache, and also signals any subscribers of the newly discovered
AddPreimage(preimage []byte) error // witness.
AddPreimages(preimages ...lntypes.Preimage) error
} }
// WalletDriver represents a "driver" for a particular concrete // WalletDriver represents a "driver" for a particular concrete

@ -3,7 +3,6 @@ package lnwallet
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"io" "io"
@ -18,6 +17,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/shachain" "github.com/lightningnetwork/lnd/shachain"
) )
@ -301,10 +301,7 @@ func CreateTestChannels() (*LightningChannel, *LightningChannel, func(), error)
aliceSigner := &input.MockSigner{Privkeys: aliceKeys} aliceSigner := &input.MockSigner{Privkeys: aliceKeys}
bobSigner := &input.MockSigner{Privkeys: bobKeys} bobSigner := &input.MockSigner{Privkeys: bobKeys}
pCache := &mockPreimageCache{ pCache := newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
// TODO(roasbeef): make mock version of pre-image store // TODO(roasbeef): make mock version of pre-image store
@ -389,25 +386,37 @@ func initRevocationWindows(chanA, chanB *LightningChannel) error {
type mockPreimageCache struct { type mockPreimageCache struct {
sync.Mutex sync.Mutex
preimageMap map[[32]byte][]byte preimageMap map[lntypes.Hash]lntypes.Preimage
} }
func (m *mockPreimageCache) LookupPreimage(hash []byte) ([]byte, bool) { func newMockPreimageCache() *mockPreimageCache {
return &mockPreimageCache{
preimageMap: make(map[lntypes.Hash]lntypes.Preimage),
}
}
func (m *mockPreimageCache) LookupPreimage(
hash lntypes.Hash) (lntypes.Preimage, bool) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
var h [32]byte p, ok := m.preimageMap[hash]
copy(h[:], hash)
p, ok := m.preimageMap[h]
return p, ok return p, ok
} }
func (m *mockPreimageCache) AddPreimage(preimage []byte) error { func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error {
preimageCopies := make([]lntypes.Preimage, 0, len(preimages))
for _, preimage := range preimages {
preimageCopies = append(preimageCopies, preimage)
}
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
m.preimageMap[sha256.Sum256(preimage[:])] = preimage for _, preimage := range preimageCopies {
m.preimageMap[preimage.Hash()] = preimage
}
return nil return nil
} }

@ -780,10 +780,7 @@ func TestCommitmentAndHTLCTransactions(t *testing.T) {
}, },
} }
pCache := &mockPreimageCache{ pCache := newMockPreimageCache()
// hash -> preimage
preimageMap: make(map[[32]byte][]byte),
}
for i, test := range testCases { for i, test := range testCases {
expectedCommitmentTx, err := txFromHex(test.expectedCommitmentTxHex) expectedCommitmentTx, err := txFromHex(test.expectedCommitmentTxHex)

23
mock.go

@ -1,7 +1,6 @@
package main package main
import ( import (
"crypto/sha256"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -16,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
) )
@ -303,25 +303,30 @@ func (m *mockSecretKeyRing) ScalarMult(keyDesc keychain.KeyDescriptor,
type mockPreimageCache struct { type mockPreimageCache struct {
sync.Mutex sync.Mutex
preimageMap map[[32]byte][]byte preimageMap map[lntypes.Hash]lntypes.Preimage
} }
func (m *mockPreimageCache) LookupPreimage(hash []byte) ([]byte, bool) { func newMockPreimageCache() *mockPreimageCache {
return &mockPreimageCache{
preimageMap: make(map[lntypes.Hash]lntypes.Preimage),
}
}
func (m *mockPreimageCache) LookupPreimage(hash lntypes.Hash) (lntypes.Preimage, bool) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
var h [32]byte p, ok := m.preimageMap[hash]
copy(h[:], hash)
p, ok := m.preimageMap[h]
return p, ok return p, ok
} }
func (m *mockPreimageCache) AddPreimage(preimage []byte) error { func (m *mockPreimageCache) AddPreimages(preimages ...lntypes.Preimage) error {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
m.preimageMap[sha256.Sum256(preimage[:])] = preimage for _, preimage := range preimages {
m.preimageMap[preimage.Hash()] = preimage
}
return nil return nil
} }

@ -316,7 +316,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB, cc *chainControl,
// HTLCs with the debug R-Hash immediately settled. // HTLCs with the debug R-Hash immediately settled.
if cfg.DebugHTLC { if cfg.DebugHTLC {
kiloCoin := btcutil.Amount(btcutil.SatoshiPerBitcoin * 1000) kiloCoin := btcutil.Amount(btcutil.SatoshiPerBitcoin * 1000)
s.invoices.AddDebugInvoice(kiloCoin, *invoices.DebugPre) s.invoices.AddDebugInvoice(kiloCoin, invoices.DebugPre)
srvrLog.Debugf("Debug HTLC invoice inserted, preimage=%x, hash=%x", srvrLog.Debugf("Debug HTLC invoice inserted, preimage=%x, hash=%x",
invoices.DebugPre[:], invoices.DebugHash[:]) invoices.DebugPre[:], invoices.DebugHash[:])
} }

@ -13,7 +13,7 @@ import (
// preimageSubscriber reprints an active subscription to be notified once the // preimageSubscriber reprints an active subscription to be notified once the
// daemon discovers new preimages, either on chain or off-chain. // daemon discovers new preimages, either on chain or off-chain.
type preimageSubscriber struct { type preimageSubscriber struct {
updateChan chan []byte updateChan chan lntypes.Preimage
quit chan struct{} quit chan struct{}
} }
@ -40,7 +40,7 @@ func (p *preimageBeacon) SubscribeUpdates() *contractcourt.WitnessSubscription {
clientID := p.clientCounter clientID := p.clientCounter
client := &preimageSubscriber{ client := &preimageSubscriber{
updateChan: make(chan []byte, 10), updateChan: make(chan lntypes.Preimage, 10),
quit: make(chan struct{}), quit: make(chan struct{}),
} }
@ -66,63 +66,74 @@ func (p *preimageBeacon) SubscribeUpdates() *contractcourt.WitnessSubscription {
// LookupPreImage attempts to lookup a preimage in the global cache. True is // LookupPreImage attempts to lookup a preimage in the global cache. True is
// returned for the second argument if the preimage is found. // returned for the second argument if the preimage is found.
func (p *preimageBeacon) LookupPreimage(payHash []byte) ([]byte, bool) { func (p *preimageBeacon) LookupPreimage(
payHash lntypes.Hash) (lntypes.Preimage, bool) {
p.RLock() p.RLock()
defer p.RUnlock() defer p.RUnlock()
// First, we'll check the invoice registry to see if we already know of // First, we'll check the invoice registry to see if we already know of
// the preimage as it's on that we created ourselves. // the preimage as it's on that we created ourselves.
var invoiceKey lntypes.Hash invoice, _, err := p.invoices.LookupInvoice(payHash)
copy(invoiceKey[:], payHash)
invoice, _, err := p.invoices.LookupInvoice(invoiceKey)
switch { switch {
case err == channeldb.ErrInvoiceNotFound: case err == channeldb.ErrInvoiceNotFound:
// If we get this error, then it simply means that this invoice // If we get this error, then it simply means that this invoice
// wasn't found, so we don't treat it as a critical error. // wasn't found, so we don't treat it as a critical error.
case err != nil: case err != nil:
return nil, false return lntypes.Preimage{}, false
} }
// If we've found the invoice, then we can return the preimage // If we've found the invoice, then we can return the preimage
// directly. // directly.
if err != channeldb.ErrInvoiceNotFound { if err != channeldb.ErrInvoiceNotFound {
return invoice.Terms.PaymentPreimage[:], true return invoice.Terms.PaymentPreimage, true
} }
// Otherwise, we'll perform a final check using the witness cache. // Otherwise, we'll perform a final check using the witness cache.
preimage, err := p.wCache.LookupWitness( preimage, err := p.wCache.LookupSha256Witness(payHash)
channeldb.Sha256HashWitness, payHash,
)
if err != nil { if err != nil {
ltndLog.Errorf("unable to lookup witness: %v", err) ltndLog.Errorf("Unable to lookup witness: %v", err)
return nil, false return lntypes.Preimage{}, false
} }
return preimage, true return preimage, true
} }
// AddPreImage adds a newly discovered preimage to the global cache, and also // AddPreimages adds a batch of newly discovered preimages to the global cache,
// signals any subscribers of the newly discovered witness. // and also signals any subscribers of the newly discovered witness.
func (p *preimageBeacon) AddPreimage(pre []byte) error { func (p *preimageBeacon) AddPreimages(preimages ...lntypes.Preimage) error {
p.Lock() // Exit early if no preimages are presented.
defer p.Unlock() if len(preimages) == 0 {
return nil
}
srvrLog.Infof("Adding preimage=%x to witness cache", pre[:]) // Copy the preimages to ensure the backing area can't be modified by
// the caller when delivering notifications.
preimageCopies := make([]lntypes.Preimage, 0, len(preimages))
for _, preimage := range preimages {
srvrLog.Infof("Adding preimage=%v to witness cache", preimage)
preimageCopies = append(preimageCopies, preimage)
}
// First, we'll add the witness to the decaying witness cache. // First, we'll add the witness to the decaying witness cache.
err := p.wCache.AddWitness(channeldb.Sha256HashWitness, pre) err := p.wCache.AddSha256Witnesses(preimages...)
if err != nil { if err != nil {
return err return err
} }
p.Lock()
defer p.Unlock()
// With the preimage added to our state, we'll now send a new // With the preimage added to our state, we'll now send a new
// notification to all subscribers. // notification to all subscribers.
for _, client := range p.subscribers { for _, client := range p.subscribers {
go func(c *preimageSubscriber) { go func(c *preimageSubscriber) {
select { for _, preimage := range preimageCopies {
case c.updateChan <- pre: select {
case <-c.quit: case c.updateChan <- preimage:
return case <-c.quit:
return
}
} }
}(client) }(client)
} }